Unverified Commit 38af4f68 authored by Lianmin Zheng, committed by GitHub

Fix grammar abort & Minor style fixes (#7204)

parent a6305c7d
@@ -15,7 +15,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
-import triton
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 import logging
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -756,7 +755,7 @@ class FlashInferMLAMultiStepDraftBackend:
         if topk > 1:
             raise ValueError(
-                f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -815,9 +814,9 @@ class FlashInferMLAMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )
         assert forward_batch.spec_info is not None
...
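Note on the next_power_of_2 change: switching from triton.next_power_of_2 to the helper in sglang.srt.utils removes a hard Triton dependency from call sites that only need the integer rounding. A minimal sketch of what such a helper computes; the exact sglang implementation may differ in style:

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n, used to pad kernel launch
        # constants so compiled kernels specialize on few stable sizes.
        return 1 << (n - 1).bit_length() if n > 0 else 1

    assert next_power_of_2(1) == 1
    assert next_power_of_2(5) == 8
    assert next_power_of_2(8) == 8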
@@ -464,11 +464,9 @@ class FlashMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
         if topk > 1:
             raise ValueError(
-                f"Currently FlashMLA only supports topk=1 for speculative decoding"
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
            )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
...
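Dropping the f prefix in these two error messages is a lint fix: the strings contain no placeholder expressions, so the f-string adds nothing (this is the pattern pyflakes/ruff flag as F541). For example:

    msg = f"only topk=1 is supported"  # flagged: f-string without any placeholders
    msg = "only topk=1 is supported"   # equivalent plain string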
@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import get_bool_env_var, get_device_core_count
+from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -766,6 +766,7 @@ class TritonMultiStepDraftBackend:
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size

     def common_template(
         self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
@@ -788,9 +789,9 @@ class TritonMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )
         for i in range(self.speculative_num_steps):
...
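Padding num_seqs, speculative_num_steps, and bs to powers of two before handing them to the kernel keeps the number of distinct compile-time specializations small. A hypothetical illustration of that effect, reusing the next_power_of_2 sketch above (the cache and kernel names here are stand-ins, not sglang code):

    _specializations = {}

    def get_kernel(batch_size: int):
        # All batch sizes in (8, 16] share one padded key, so they
        # reuse one compiled kernel instead of compiling per size.
        key = next_power_of_2(batch_size)
        if key not in _specializations:
            _specializations[key] = f"kernel<BLOCK={key}>"  # stand-in for a JIT compile
        return _specializations[key]

    assert get_kernel(9) is get_kernel(16)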
@@ -708,7 +708,7 @@ def decode_attention_fwd(
         num_kv_splits,
         max_kv_splits,
         sm_scale,
-        logit_cap,
+        logit_cap=logit_cap,
     )
 else:
     # GQA/MQA/MLA
@@ -724,5 +724,5 @@ def decode_attention_fwd(
         num_kv_splits,
         max_kv_splits,
         sm_scale,
-        logit_cap,
+        logit_cap=logit_cap,
     )
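Passing logit_cap by keyword protects these call sites from positional drift: if the callee later gains an optional parameter before logit_cap, a positional call would silently bind the value to the wrong argument. A hypothetical reduced signature (not the real kernel wrapper) showing the failure mode:

    def attention_kernel(sm_scale, page_size=1, logit_cap=0.0):
        return sm_scale, page_size, logit_cap

    attention_kernel(0.125, 30.0)            # 30.0 silently lands in page_size
    attention_kernel(0.125, logit_cap=30.0)  # keyword form stays correct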
@@ -18,7 +18,6 @@ from typing import Optional
 from torch import nn
-from sglang.srt.layers.linear import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        attn_type=AttentionType.DECODER,
-        prefix: str = "",
+        attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
...
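Annotating attn_type and moving prefix after use_irope only affects callers that pass these arguments positionally; keyword callers are untouched. A reduced, hypothetical mirror of the signature change:

    class Attn:  # stand-in for RadixAttention's trailing parameters
        def __init__(self, attn_type="decoder", use_irope=False, prefix=""):
            self.attn_type, self.use_irope, self.prefix = attn_type, use_irope, prefix

    # Keyword calls are order-independent, so the reorder is safe for them.
    a = Attn(prefix="model.layers.0.self_attn", use_irope=True)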
@@ -2108,7 +2108,8 @@ class Scheduler(
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
             if req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
-                req.grammar.cancel()
+                if req.grammar:
+                    req.grammar.cancel()
                 req.set_finish_with_abort("Aborted by AbortReq.")
         # Delete requests in the running batch
...
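This is the grammar-abort fix from the commit title: a request sitting in the grammar queue may not yet have a grammar object attached (req.grammar can be None), so the old unconditional cancel() could raise AttributeError during abort. A minimal reproduction with a stand-in Req class (hypothetical, for illustration):

    class Req:
        def __init__(self, grammar=None):
            self.grammar = grammar  # may still be None while queued

    req = Req()
    # req.grammar.cancel()  # AttributeError: 'NoneType' object has no attribute 'cancel'
    if req.grammar:
        req.grammar.cancel()  # safely skipped when no grammar is attached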
@@ -141,15 +141,12 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    @abc.abstractmethod
     def get_flat_data(self, indices):
         raise NotImplementedError()

-    @abc.abstractmethod
     def transfer(self, indices, flat_data):
         raise NotImplementedError()

-    @abc.abstractmethod
     def transfer_per_layer(self, indices, flat_data, layer_id):
         raise NotImplementedError()
...
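Removing @abc.abstractmethod turns get_flat_data, transfer, and transfer_per_layer into optional hooks: subclasses that never use flat-data transfer no longer need stub overrides just to be instantiable, while the retained raise NotImplementedError() still fails loudly if one is actually called. A small self-contained illustration (class names are hypothetical):

    import abc

    class Cache(abc.ABC):
        @abc.abstractmethod
        def set_kv_buffer(self):
            ...

        # No longer abstract: subclasses may simply omit it.
        def get_flat_data(self, indices):
            raise NotImplementedError()

    class MyCache(Cache):
        def set_kv_buffer(self):
            pass

    MyCache()  # instantiable even without overriding get_flat_data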
@@ -86,8 +86,8 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
         self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
-        self.accept_length = (
-            torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
-        )
+        self.accept_length = torch.full(
+            (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
+        )
         # Capture
...
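torch.full states the intent directly and skips the intermediate ones tensor plus the elementwise multiply of the old expression; the result is identical:

    import torch

    max_bs, num_tokens_per_bs = 4, 3
    old = torch.ones((max_bs,), dtype=torch.int32) * num_tokens_per_bs  # two ops, one temporary
    new = torch.full((max_bs,), num_tokens_per_bs, dtype=torch.int32)   # one allocation
    assert torch.equal(old, new)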