"vscode:/vscode.git/clone" did not exist on "7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0"
Unverified Commit de2dd738 authored by Even Zhou's avatar Even Zhou Committed by GitHub
Browse files

Revert "[feature] Rework Ascend NPU graph support" (#9385)

parent 1ec97697
...@@ -9,7 +9,7 @@ from transformers import AutoConfig ...@@ -9,7 +9,7 @@ from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_triton, fused_moe as fused_moe_triton,
) )
from sglang.srt.model_executor.graph_runner import set_torch_compile_config from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
def get_model_config(model_name: str, tp_size: int): def get_model_config(model_name: str, tp_size: int):
......
...@@ -55,7 +55,7 @@ _is_npu = is_npu() ...@@ -55,7 +55,7 @@ _is_npu = is_npu()
@dataclass @dataclass
class GraphCaptureContext: class GraphCaptureContext:
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream stream: torch.cuda.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
...@@ -252,13 +252,9 @@ class GroupCoordinator: ...@@ -252,13 +252,9 @@ class GroupCoordinator:
if is_cuda_alike(): if is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}") self.device = torch.device(f"cuda:{local_rank}")
elif _is_npu:
self.device = torch.device(f"npu:{local_rank}")
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.device_module = torch.get_device_module(self.device)
self.use_pynccl = use_pynccl self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
...@@ -406,7 +402,7 @@ class GroupCoordinator: ...@@ -406,7 +402,7 @@ class GroupCoordinator:
self, graph_capture_context: Optional[GraphCaptureContext] = None self, graph_capture_context: Optional[GraphCaptureContext] = None
): ):
if graph_capture_context is None: if graph_capture_context is None:
stream = self.device_module.Stream() stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream) graph_capture_context = GraphCaptureContext(stream)
else: else:
stream = graph_capture_context.stream stream = graph_capture_context.stream
...@@ -417,11 +413,11 @@ class GroupCoordinator: ...@@ -417,11 +413,11 @@ class GroupCoordinator:
# ensure all initialization operations complete before attempting to # ensure all initialization operations complete before attempting to
# capture the graph on another stream # capture the graph on another stream
curr_stream = self.device_module.current_stream() curr_stream = torch.cuda.current_stream()
if curr_stream != stream: if curr_stream != stream:
stream.wait_stream(curr_stream) stream.wait_stream(curr_stream)
with self.device_module.stream(stream), maybe_ca_context: with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective # In graph mode, we have to be very careful about the collective
# operations. The current status is: # operations. The current status is:
# allreduce \ Mode | Eager | Graph | # allreduce \ Mode | Eager | Graph |
...@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ...@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
) )
elif hasattr(torch, "xpu") and torch.xpu.is_available(): elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.empty_cache() torch.xpu.empty_cache()
elif hasattr(torch, "npu") and torch.npu.is_available():
torch.npu.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
......
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
import torch_npu import torch_npu
...@@ -27,7 +27,6 @@ class ForwardMetadata: ...@@ -27,7 +27,6 @@ class ForwardMetadata:
# seq len inputs # seq len inputs
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_int: Optional[torch.Tensor] = None seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
class AscendAttnBackend(AttentionBackend): class AscendAttnBackend(AttentionBackend):
...@@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend): ...@@ -52,7 +51,7 @@ class AscendAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner): def __init__(self, model_runner: ModelRunner):
super().__init__() super().__init__()
self.forward_metadata = None self.forward_metadata = ForwardMetadata()
self.device = model_runner.device self.device = model_runner.device
self.gen_attention_mask(128, model_runner.dtype) self.gen_attention_mask(128, model_runner.dtype)
self.page_size = model_runner.page_size self.page_size = model_runner.page_size
...@@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend): ...@@ -61,15 +60,9 @@ class AscendAttnBackend(AttentionBackend):
self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.native_attn = TorchNativeAttnBackend(model_runner) self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.graph_mode = False
def init_forward_metadata(self, forward_batch: ForwardBatch): def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass.""" """Init the metadata for a forward pass."""
self.forward_metadata = ForwardMetadata()
self.forward_metadata.block_tables = ( self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : forward_batch.seq_lens.max() forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
...@@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend): ...@@ -82,63 +75,6 @@ class AscendAttnBackend(AttentionBackend):
) )
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.graph_mode = False
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Pre-allocate the buffers shared by all captured graphs.

    Sizes the block-table buffer for the worst case: ``max_bs`` requests,
    each spanning the full model context length, one entry per page.
    """
    worst_case_pages = self.max_context_len // self.page_size
    block_tables = torch.empty(
        (max_bs, worst_case_pages),
        dtype=torch.int32,
        device=self.device,
    )
    self.graph_metadata = {"block_tables": block_tables}
def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Build and remember the metadata used while capturing a graph for batch size ``bs``."""
    capture_md = ForwardMetadata()
    # A view (not a copy) of the shared buffer, so replays can refill it in place.
    capture_md.block_tables = self.graph_metadata["block_tables"][:bs, :]
    # Host-side Python list, later fed to actual_seq_lengths_kv of the fused kernel.
    capture_md.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
    # Keyed per batch size so replay can find the matching metadata object.
    self.graph_metadata[bs] = capture_md
    self.forward_metadata = capture_md
    self.graph_mode = True
def init_forward_metadata_replay_cuda_graph(
    self,
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    seq_lens_cpu: Optional[torch.Tensor],
):
    """Refresh the captured metadata buffers in place before replaying batch size ``bs``."""
    # Reuse the metadata object stashed at capture time for this batch size.
    metadata = self.graph_metadata[bs]
    max_len = seq_lens_cpu[:bs].max().item()
    # Pages needed to cover the longest sequence (ceil division by page size).
    max_seq_pages = (max_len + self.page_size - 1) // self.page_size
    # Refill the block table in place: sample every page_size-th token slot,
    # then convert token indices to page indices via integer division.
    metadata.block_tables[:bs, :max_seq_pages].copy_(
        self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
        // self.page_size
    )
    # Zero the unused tail columns and rows so stale entries from earlier
    # replays are never read by the kernel.
    metadata.block_tables[:bs, max_seq_pages:].fill_(0)
    metadata.block_tables[bs:, :].fill_(0)
    self.forward_metadata = metadata
    self.graph_mode = True
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
...@@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend): ...@@ -231,74 +167,28 @@ class AscendAttnBackend(AttentionBackend):
layer, forward_batch.out_cache_loc, k, v layer, forward_batch.out_cache_loc, k, v
) )
if not self.use_mla: if not self.use_mla:
if self.graph_mode: k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer( v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
layer.layer_id
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
num_tokens = query.shape[0]
workspace = (
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
)
)
output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype,
device=q.device,
)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
workspace=workspace,
out=[output, softmax_lse],
)
else:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
)
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
num_tokens = query.shape[0] num_tokens = query.shape[0]
output = torch.empty( output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim), (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype, dtype=query.dtype,
device=query.device, device=query.device,
) )
torch_npu._npu_paged_attention( torch_npu._npu_paged_attention(
query=query, query=query,
key_cache=k_cache, key_cache=k_cache,
value_cache=v_cache, value_cache=v_cache,
num_heads=layer.tp_q_head_num, num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num, num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling, scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables, block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int, context_lens=self.forward_metadata.seq_lens_cpu_int,
out=output, out=output,
) )
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else: else:
query = q.view(-1, layer.tp_q_head_num, layer.head_dim) query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
......
...@@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache): ...@@ -376,7 +376,7 @@ class MHATokenToKVPool(KVCache):
v_scale: Optional[float] = None, v_scale: Optional[float] = None,
layer_id_override: Optional[int] = None, layer_id_override: Optional[int] = None,
): ):
from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if layer_id_override is not None: if layer_id_override is not None:
layer_id = layer_id_override layer_id = layer_id_override
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Run the model with device graph and torch.compile.""" """Run the model with cuda graph and torch.compile."""
from __future__ import annotations from __future__ import annotations
...@@ -221,7 +221,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ...@@ -221,7 +221,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
return capture_bs, compile_bs return capture_bs, compile_bs
# Reuse this memory pool across all device graph runners. # Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None global_graph_memory_pool = None
...@@ -234,14 +234,12 @@ def set_global_graph_memory_pool(val): ...@@ -234,14 +234,12 @@ def set_global_graph_memory_pool(val):
global_graph_memory_pool = val global_graph_memory_pool = val
class GraphRunner: class CudaGraphRunner:
"""A GraphRunner is a base class to run the forward pass of a model with device graph and torch.compile.""" """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
def __init__(self, model_runner: ModelRunner): def __init__(self, model_runner: ModelRunner):
# Parse args # Parse args
self.model_runner = model_runner self.model_runner = model_runner
self.device = model_runner.device
self.device_module = torch.get_device_module(self.device)
self.graphs = {} self.graphs = {}
self.output_buffers = {} self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.enable_torch_compile = model_runner.server_args.enable_torch_compile
...@@ -267,7 +265,7 @@ class GraphRunner: ...@@ -267,7 +265,7 @@ class GraphRunner:
# Batch sizes to capture # Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
rank0_log(f"Capture graph bs {self.capture_bs}") rank0_log(f"Capture cuda graph bs {self.capture_bs}")
self.capture_forward_mode = ForwardMode.DECODE self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1 self.num_tokens_per_bs = 1
...@@ -307,15 +305,13 @@ class GraphRunner: ...@@ -307,15 +305,13 @@ class GraphRunner:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
# Graph inputs # Graph inputs
with torch.device(self.device): with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full( self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
) )
self.out_cache_loc = torch.zeros( self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
(self.max_num_token,), dtype=self._cache_loc_dtype()
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
...@@ -370,12 +366,12 @@ class GraphRunner: ...@@ -370,12 +366,12 @@ class GraphRunner:
* self.num_tokens_per_bs * self.num_tokens_per_bs
), ),
dtype=torch.bool, dtype=torch.bool,
device=self.device, device="cuda",
) )
self.next_token_logits_buffer = torch.zeros( self.next_token_logits_buffer = torch.zeros(
(self.max_num_token, self.model_runner.model_config.vocab_size), (self.max_num_token, self.model_runner.model_config.vocab_size),
dtype=torch.float, dtype=torch.float,
device=self.device, device="cuda",
) )
# Capture # Capture
...@@ -384,12 +380,9 @@ class GraphRunner: ...@@ -384,12 +380,9 @@ class GraphRunner:
self.capture() self.capture()
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture device graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
) )
def _cache_loc_dtype(self):
return torch.int64
def can_run(self, forward_batch: ForwardBatch): def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather: if self.require_mlp_tp_gather:
cuda_graph_bs = ( cuda_graph_bs = (
...@@ -509,16 +502,8 @@ class GraphRunner: ...@@ -509,16 +502,8 @@ class GraphRunner:
) )
logger.info(log_message) logger.info(log_message)
def _capture_graph(self, graph, pool, stream, run_once_fn):
with self.device_module.graph(graph, pool=pool, stream=stream):
out = run_once_fn()
return out
def _create_device_graph(self):
pass
def capture_one_batch_size(self, bs: int, forward: Callable): def capture_one_batch_size(self, bs: int, forward: Callable):
graph = self._create_device_graph() graph = torch.cuda.CUDAGraph()
stream = self.stream stream = self.stream
num_tokens = bs * self.num_tokens_per_bs num_tokens = bs * self.num_tokens_per_bs
...@@ -658,17 +643,19 @@ class GraphRunner: ...@@ -658,17 +643,19 @@ class GraphRunner:
return logits_output_or_pp_proxy_tensors return logits_output_or_pp_proxy_tensors
for _ in range(2): for _ in range(2):
self.device_module.synchronize() torch.cuda.synchronize()
self.model_runner.tp_group.barrier() self.model_runner.tp_group.barrier()
run_once() run_once()
if get_global_graph_memory_pool() is None: if get_global_graph_memory_pool() is None:
set_global_graph_memory_pool(self.device_module.graph_pool_handle()) set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
# Set graph pool id globally to be able to use symmetric memory # Set graph pool id globally to be able to use symmetric memory
set_graph_pool_id(get_global_graph_memory_pool()) set_graph_pool_id(get_global_graph_memory_pool())
out = self._capture_graph( with torch.cuda.graph(
graph, get_global_graph_memory_pool(), stream, run_once graph, pool=get_global_graph_memory_pool(), stream=stream
) ):
out = run_once()
return graph, out return graph, out
...@@ -850,7 +837,7 @@ class GraphRunner: ...@@ -850,7 +837,7 @@ class GraphRunner:
return spec_info return spec_info
GRAPH_CAPTURE_FAILED_MSG = ( CUDA_GRAPH_CAPTURE_FAILED_MSG = (
"Possible solutions:\n" "Possible solutions:\n"
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.model_executor.graph_runner import GraphRunner
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
class CudaGraphRunner(GraphRunner):
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.

    All shared capture/replay machinery lives in :class:`GraphRunner`; this
    subclass only supplies the CUDA-specific graph object.
    """

    # NOTE: the redundant __init__ that only forwarded to super().__init__
    # has been dropped; the inherited constructor is identical.

    def _create_device_graph(self):
        """Return a fresh CUDA graph object for capture."""
        return torch.cuda.CUDAGraph()
...@@ -89,11 +89,8 @@ from sglang.srt.mem_cache.memory_pool import ( ...@@ -89,11 +89,8 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool, ReqToTokenPool,
SWAKVPool, SWAKVPool,
) )
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
# TODO(iforgetmyname): Renaming on the way
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
...@@ -344,12 +341,9 @@ class ModelRunner: ...@@ -344,12 +341,9 @@ class ModelRunner:
if self.device == "cuda": if self.device == "cuda":
self.init_cublas() self.init_cublas()
self.init_attention_backend() self.init_attention_backend()
self.init_device_graphs() self.init_cuda_graphs()
elif self.device == "npu":
self.init_attention_backend()
self.init_device_graphs()
else: else:
self.graph_runner = None self.cuda_graph_runner = None
self.cuda_graph_mem_usage = 0 self.cuda_graph_mem_usage = 0
self.init_attention_backend() self.init_attention_backend()
...@@ -923,8 +917,7 @@ class ModelRunner: ...@@ -923,8 +917,7 @@ class ModelRunner:
) )
# We need to get device after patch otherwise the device would be wrong # We need to get device after patch otherwise the device would be wrong
self.device_module = torch.get_device_module(self.device) infered_device = torch.cuda.current_device()
infered_device = self.device_module.current_device()
named_tensors = [ named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device)) (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
...@@ -1592,9 +1585,9 @@ class ModelRunner: ...@@ -1592,9 +1585,9 @@ class ModelRunner:
.cuda() .cuda()
) )
def init_device_graphs(self): def init_cuda_graphs(self):
"""Capture cuda graphs.""" """Capture cuda graphs."""
self.graph_runner = None self.cuda_graph_runner = None
self.cuda_graph_mem_usage = 0 self.cuda_graph_mem_usage = 0
if not self.is_generation: if not self.is_generation:
...@@ -1609,9 +1602,8 @@ class ModelRunner: ...@@ -1609,9 +1602,8 @@ class ModelRunner:
logger.info( logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
) )
self.graph_runner = ( self.cuda_graph_runner = CudaGraphRunner(self)
CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id) after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem self.cuda_graph_mem_usage = before_mem - after_mem
logger.info( logger.info(
...@@ -1763,11 +1755,11 @@ class ModelRunner: ...@@ -1763,11 +1755,11 @@ class ModelRunner:
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
can_run_cuda_graph = bool( can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph() forward_batch.forward_mode.is_cuda_graph()
and self.graph_runner and self.cuda_graph_runner
and self.graph_runner.can_run(forward_batch) and self.cuda_graph_runner.can_run(forward_batch)
) )
if can_run_cuda_graph: if can_run_cuda_graph:
ret = self.graph_runner.replay( ret = self.cuda_graph_runner.replay(
forward_batch, forward_batch,
skip_attn_backend_init=skip_attn_backend_init, skip_attn_backend_init=skip_attn_backend_init,
pp_proxy_tensors=pp_proxy_tensors, pp_proxy_tensors=pp_proxy_tensors,
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with npu graph and torch.compile."""
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.srt.model_executor.graph_runner import GraphRunner
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
class NPUGraphRunner(GraphRunner):
    """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__(model_runner)

    def _create_device_graph(self):
        """Return a fresh NPU graph object for capture."""
        return torch.npu.NPUGraph()

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        """Record one execution of ``run_once_fn`` into ``graph`` on the NPU stream."""
        with torch.npu.graph(
            graph,
            pool=pool,
            stream=stream,
            # NOTE(review): NPU-specific flag; presumably required for the
            # cpu_update_input path used in replay() — confirm against torch_npu docs.
            auto_dispatch_capture=True,
        ):
            out = run_once_fn()
        return out

    def _update_inputs(self, seq_lens):
        # Push the per-replay KV sequence lengths into the already-captured graph.
        self.graphs[self.bs].update(
            cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}]
        )

    def _cache_loc_dtype(self):
        """NPU kernels take int32 cache locations (base class uses int64)."""
        return torch.int32

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        """Replay the captured graph for this batch, returning de-padded outputs."""
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        # Pad seq_lens up to the captured batch size, then overlap the graph
        # input update (a separate thread) with the graph replay itself.
        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
        thread.start()
        self.graphs[self.bs].replay()
        thread.join()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            # Strip padding rows so callers only see the real tokens.
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
...@@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1200,7 +1200,7 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
zero_allocator: BumpAllocator, zero_allocator: BumpAllocator,
): ):
from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm: if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
......
...@@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -68,8 +68,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import ( from sglang.srt.models.deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
......
...@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module): ...@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
from sglang.srt.model_executor.graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = ( batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
self._batch_image_inputs(forward_batch) self._batch_image_inputs(forward_batch)
......
...@@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention ...@@ -22,8 +22,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2 import Qwen2Model
......
...@@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope ...@@ -52,8 +52,8 @@ from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.graph_runner import get_is_capture_mode
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel from sglang.srt.models.qwen2_moe import Qwen2MoeModel
......
...@@ -6,22 +6,20 @@ from typing import TYPE_CHECKING, Callable ...@@ -6,22 +6,20 @@ from typing import TYPE_CHECKING, Callable
import torch import torch
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
# TODO(iforgetmyname): Renaming on the way CUDA_GRAPH_CAPTURE_FAILED_MSG,
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner CudaGraphRunner,
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.graph_runner import (
GRAPH_CAPTURE_FAILED_MSG,
get_batch_sizes_to_capture, get_batch_sizes_to_capture,
get_global_graph_memory_pool, get_global_graph_memory_pool,
model_capture_mode, model_capture_mode,
set_global_graph_memory_pool, set_global_graph_memory_pool,
set_torch_compile_config, set_torch_compile_config,
) )
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.utils import ( from sglang.srt.utils import (
require_attn_tp_gather, require_attn_tp_gather,
...@@ -123,7 +121,7 @@ class EAGLEDraftCudaGraphRunner: ...@@ -123,7 +121,7 @@ class EAGLEDraftCudaGraphRunner:
self.capture() self.capture()
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
) )
def can_run(self, forward_batch: ForwardBatch): def can_run(self, forward_batch: ForwardBatch):
......
...@@ -6,16 +6,9 @@ from typing import TYPE_CHECKING, Callable ...@@ -6,16 +6,9 @@ from typing import TYPE_CHECKING, Callable
import torch import torch
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
# TODO(iforgetmyname): Renaming on the way CUDA_GRAPH_CAPTURE_FAILED_MSG,
from sglang.srt.model_executor.cuda_graph_runner_impl import CudaGraphRunner CudaGraphRunner,
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.graph_runner import (
GRAPH_CAPTURE_FAILED_MSG,
LogitsProcessorOutput, LogitsProcessorOutput,
get_batch_sizes_to_capture, get_batch_sizes_to_capture,
get_global_graph_memory_pool, get_global_graph_memory_pool,
...@@ -23,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import ( ...@@ -23,6 +16,11 @@ from sglang.srt.model_executor.graph_runner import (
set_global_graph_memory_pool, set_global_graph_memory_pool,
set_torch_compile_config, set_torch_compile_config,
) )
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.utils import ( from sglang.srt.utils import (
require_attn_tp_gather, require_attn_tp_gather,
...@@ -151,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner: ...@@ -151,7 +149,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.capture() self.capture()
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed: {e}\n{GRAPH_CAPTURE_FAILED_MSG}" f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
) )
def can_run(self, forward_batch: ForwardBatch): def can_run(self, forward_batch: ForwardBatch):
......
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Per-model quality/performance floors consumed by the test class below:
#   "accuracy"          - minimum acceptable few-shot GSM8K accuracy
#                         (compared with assertGreaterEqual).
#   "output_throughput" - minimum acceptable offline benchmark throughput
#                         (compared with assertGreater when running in CI).
#   "latency"           - budget value; not referenced by the tests visible
#                         in this file — presumably used elsewhere (verify).
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 150,
        "output_throughput": 30,
    },
}
class TestAscendGraphTp1Bf16(CustomTestCase):
    """Accuracy and throughput checks for the Ascend attention backend (TP=1, bf16).

    Each model listed in ``TEST_MODEL_MATRIX`` is exercised twice:
    once for GSM8K few-shot accuracy against a launched server, and once
    for offline benchmark throughput.
    """

    @classmethod
    def setUpClass(cls):
        # Shared launch configuration for every model in the matrix.
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
        ]

    def test_a_gsm8k(self):
        # Launch a server per model, run the GSM8K eval against it, and
        # always tear the server down — even if the eval raises.
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                server_proc = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[*self.common_args],
                )
                try:
                    eval_args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    result = run_eval_few_shot_gsm8k(eval_args)
                    floor = TEST_MODEL_MATRIX[model]["accuracy"]
                    self.assertGreaterEqual(result["accuracy"], floor)
                finally:
                    kill_process_tree(server_proc.pid)

    def test_b_throughput(self):
        # Offline throughput benchmark; the floor is only enforced in CI.
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                throughput = run_bench_offline_throughput(
                    model,
                    [*self.common_args],
                )
                print(f"##=== {model} throughput: {throughput} ===##")
                if is_in_ci():
                    floor = TEST_MODEL_MATRIX[model]["output_throughput"]
                    self.assertGreater(throughput, floor)
# Allow running this test file directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Per-model quality/performance floors consumed by the test class below:
#   "accuracy"          - minimum acceptable few-shot GSM8K accuracy
#                         (compared with assertGreaterEqual).
#   "output_throughput" - minimum acceptable offline benchmark throughput
#                         (compared with assertGreater when running in CI).
#   "latency"           - budget value; not referenced by the tests visible
#                         in this file — presumably used elsewhere (verify).
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}
class TestAscendGraphTp2Bf16(CustomTestCase):
    """Accuracy and throughput checks for the Ascend attention backend (TP=2, bf16).

    Identical flow to the TP=1 variant, but the server/benchmark is
    launched with ``--tp-size 2``.
    """

    @classmethod
    def setUpClass(cls):
        # Shared launch configuration for every model in the matrix.
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
        ]

    def test_a_gsm8k(self):
        # Launch a server per model, run the GSM8K eval against it, and
        # always tear the server down — even if the eval raises.
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                server_proc = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[*self.common_args],
                )
                try:
                    eval_args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    result = run_eval_few_shot_gsm8k(eval_args)
                    floor = TEST_MODEL_MATRIX[model]["accuracy"]
                    self.assertGreaterEqual(result["accuracy"], floor)
                finally:
                    kill_process_tree(server_proc.pid)

    def test_b_throughput(self):
        # Offline throughput benchmark; the floor is only enforced in CI.
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                throughput = run_bench_offline_throughput(
                    model,
                    [*self.common_args],
                )
                print(f"##=== {model} throughput: {throughput} ===##")
                if is_in_ci():
                    floor = TEST_MODEL_MATRIX[model]["output_throughput"]
                    self.assertGreater(throughput, floor)
# Allow running this test file directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
...@@ -269,11 +269,9 @@ suite_xeon = { ...@@ -269,11 +269,9 @@ suite_xeon = {
suite_ascend = { suite_ascend = {
"per-commit-1-ascend-npu": [ "per-commit-1-ascend-npu": [
TestFile("ascend/test_ascend_tp1_bf16.py", 400), TestFile("ascend/test_ascend_tp1_bf16.py", 400),
TestFile("ascend/test_ascend_graph_tp1_bf16.py", 400),
], ],
"per-commit-2-ascend-npu": [ "per-commit-2-ascend-npu": [
TestFile("ascend/test_ascend_tp2_bf16.py", 400), TestFile("ascend/test_ascend_tp2_bf16.py", 400),
TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
], ],
"per-commit-4-ascend-npu": [ "per-commit-4-ascend-npu": [
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400), TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment