Unverified commit 9e1014cf, authored by Lianmin Zheng, committed by GitHub

Revert "Add fast decode plan for flashinfer mla" (#4008)

parent fa561067
@@ -186,5 +186,5 @@ Please consult the documentation below to learn more about the parameters you ma…
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden.
+* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models.
 * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
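For context on the MLA-related options documented in this hunk, the sketch below sets the same flags programmatically. It is a minimal sketch, assuming the `sglang` package is installed and that these CLI flags map one-to-one onto `ServerArgs` fields (sglang's usual convention); the model path and exact field names are illustrative, so verify them against your installed version.

```python
# Sketch only: mirrors the server arguments documented above via ServerArgs.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V2-Lite",   # illustrative MLA (DeepSeek-family) model
    trust_remote_code=True,
    enable_flashinfer_mla=True,                  # use the flashinfer MLA wrapper
    flashinfer_mla_disable_ragged=False,         # keep the ragged prefill wrapper enabled
    triton_attention_num_kv_splits=8,            # KV splits used by the triton decode kernels
)
print(server_args)
```

On the command line, the same options are passed as flags such as `--enable-flashinfer-mla` and `--flashinfer-mla-disable-ragged` to `python -m sglang.launch_server`.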
@@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be…
 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
-- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
+- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). (In Experiment Stage)
 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
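The "Weight Absorption" bullet in the hunk above is just the associative law of matrix multiplication applied to MLA decoding: instead of up-projecting the cached latent vectors into full keys and then dotting them with the query, the up-projection is folded ("absorbed") into the query once, and the per-token dot products stay in the smaller latent dimension. A minimal PyTorch sketch of the equivalence, with illustrative shapes and names that are not taken from the model code:

```python
# Illustrative only: (c @ W^T) @ q == c @ (W^T @ q), so decode-time work can
# stay in the latent dimension instead of materializing full keys.
import torch

torch.manual_seed(0)
d_head, d_latent, seq_len = 128, 64, 16        # illustrative sizes
W_uk = torch.randn(d_head, d_latent)           # latent -> key up-projection
q = torch.randn(d_head)                        # one decode-step query head
c_kv = torch.randn(seq_len, d_latent)          # compressed (latent) KV cache

# Naive order: materialize full keys, then dot with the query.
scores_naive = (c_kv @ W_uk.T) @ q             # (seq_len,)

# Absorbed order: fold W_uk into the query once, then stay in latent space.
q_absorbed = W_uk.T @ q                        # (d_latent,)
scores_absorbed = c_kv @ q_absorbed            # (seq_len,)

assert torch.allclose(scores_naive, scores_absorbed, atol=1e-3)
```

Which order is cheaper depends on batch size and cache length, which is why the bullet frames the method as balancing computation against memory access.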
@@ -29,8 +29,9 @@ class AttentionBackend(ABC):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        **kwargs,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -41,8 +42,9 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        **kwargs,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
@@ -269,10 +269,9 @@ class FlashInferAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        forward_mode: ForwardMode,
         encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
-        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []

@@ -340,10 +339,9 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        forward_mode: ForwardMode,
         encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
-        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -10,7 +10,6 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
 """

 from dataclasses import dataclass
-from functools import partial
 from typing import TYPE_CHECKING, Optional, Union

 import torch

@@ -28,12 +27,14 @@ from sglang.srt.utils import is_flashinfer_available
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
     )
+    from flashinfer.cascade import merge_state


 @dataclass
@@ -62,7 +63,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
-        self.device = model_runner.device

         global_config.enable_flashinfer_mla = True

@@ -85,6 +85,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+
         self.q_indptr_decode = torch.arange(
             0, max_bs + 1, dtype=torch.int32, device=model_runner.device
         )
@@ -122,7 +126,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 decode_wrapper=self.decode_wrapper,
-                init_metadata_replay=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrapper)
         else:

@@ -158,20 +161,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             cuda_graph_kv_indices = kv_indices_buf

         self.cuda_graph_kv_indices = cuda_graph_kv_indices
-        self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
-        self.cuda_graph_kv_indptr = self.kv_indptr.clone()
-        self.cuda_graph_kv_lens = torch.ones(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
-
-        # For fast decode plan in graph replaying
-        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
-        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
-        self.fast_decode_kwargs = {
-            "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
-            "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
-            "kv_indices": self.cuda_graph_kv_indices,
-        }
+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = self.kv_indptr.clone()
+        self.cuda_graph_qo_indptr = self.kv_indptr.clone()

     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -179,17 +175,18 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        **kwargs,
+        spec_info: Optional[SpecInfo],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
                 use_cuda_graph=True,
-                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
-                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
+                qo_indptr=self.qo_indptr[: num_tokens + 1],
+                kv_indptr=self.kv_indptr[: num_tokens + 1],
                 kv_indices=self.cuda_graph_kv_indices,
-                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
+                kv_len_arr=self.kv_last_page_len[:num_tokens],
                 backend="auto",
             )

@@ -199,11 +196,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 seq_lens,
                 seq_lens_sum,
                 decode_wrapper=decode_wrapper,
-                init_metadata_replay=False,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrapper
             self.forward_metadata = DecodeMetadata(decode_wrapper)
-            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -213,30 +208,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        seq_lens_cpu: torch.Tensor,
-        **kwargs,
+        spec_info: Optional[SpecInfo],
     ):
         if forward_mode.is_decode_or_idle():
-            kv_len_arr_cpu = seq_lens_cpu[:bs]
-            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
-            )
-            self.fast_decode_kwargs.update(
-                {
-                    "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
-                    "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
-                    "kv_len_arr_cpu": kv_len_arr_cpu,
-                }
-            )
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
                 seq_lens_sum,
                 decode_wrapper=self.decode_cuda_graph_metadata[bs],
-                init_metadata_replay=True,
-                **self.fast_decode_kwargs,
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -336,6 +317,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode

@@ -345,8 +327,6 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
-        init_metadata_replay: bool = False,
-        **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
         self.call_begin_forward(

@@ -356,8 +336,6 @@ class FlashInferMLAIndicesUpdaterDecode:
             seq_lens_sum,
             self.q_indptr,
             self.kv_indptr,
-            init_metadata_replay,
-            **fast_decode_kwargs,
         )

     def call_begin_forward(
@@ -368,19 +346,14 @@ class FlashInferMLAIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
-        init_metadata_replay: bool = False,
-        **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
         q_indptr = q_indptr[: bs + 1]
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = (
-            torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
-            if not init_metadata_replay
-            else fast_decode_kwargs["kv_indices"]
-        )
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling

@@ -393,36 +366,21 @@ class FlashInferMLAIndicesUpdaterDecode:
             kv_indices,
             self.req_to_token.shape[1],
         )
-        if not init_metadata_replay:
-            wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
-        else:
-            wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
-            )
+        wrapper.plan(
+            q_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_lens,
+            self.num_local_heads,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            1,
+            False,
+            sm_scale,
+            self.data_type,
+            self.data_type,
+        )


 class FlashInferMLAIndicesUpdaterPrefill:
@@ -442,6 +400,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
@@ -538,42 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill:
             self.q_data_type,
             self.data_type,
         )
-
-
-def fast_mla_decode_plan(
-    self,
-    qo_indptr_cpu: torch.Tensor,
-    kv_indptr_cpu: torch.Tensor,
-    kv_indices: torch.Tensor,
-    kv_len_arr_cpu: torch.Tensor,
-    num_heads: int,
-    head_dim_ckv: int,
-    head_dim_kpe: int,
-    page_size: int,
-    causal: bool,
-    sm_scale: float,
-    q_data_type: torch.dtype,
-    kv_data_type: torch.dtype,
-) -> None:
-    """A faster version of BatchMLAPagedAttentionWrapper::plan,
-    for skipping the stream synchronization in original plan function during
-    cuda graph replaying.
-    """
-    self._causal = causal
-    self._page_size = page_size
-    self._sm_scale = sm_scale
-
-    with self.device as device:
-        stream = torch.cuda.current_stream(device).cuda_stream
-        self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_cpu,
-            kv_indptr_cpu,
-            kv_len_arr_cpu,
-            num_heads,
-            head_dim_ckv,
-            causal,
-            stream,
-        )
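For readers reconstructing what is being reverted: the deleted `fast_mla_decode_plan` above was attached per wrapper instance with `functools.partial` (see the removed `decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)` line in an earlier hunk), so that CUDA graph replay could re-plan from CPU-resident index buffers without the stream synchronization of the stock `plan`. The snippet below is only a generic sketch of that instance-level override pattern, with made-up class and function names, not the flashinfer API:

```python
# Generic illustration of the technique used by the reverted change: rebind one
# instance's method to a leaner replacement via functools.partial.
from functools import partial


class Planner:
    def plan(self, lengths):
        # Stock path: imagine extra validation / synchronization here.
        self.total = sum(lengths)
        return "slow plan"


def fast_plan(self, lengths):
    # Replacement path: same state update, without the expensive extras.
    self.total = sum(lengths)
    return "fast plan"


p = Planner()
p.plan = partial(fast_plan, p)   # only this instance is patched
print(p.plan([1, 2, 3]))         # -> fast plan
```

An ordinary `types.MethodType` binding would work equally well; `partial` simply matches what the reverted code did.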
@@ -230,10 +230,9 @@ class TritonAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        forward_mode: ForwardMode,
         encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
-        **kwargs,
     ):
         assert encoder_lens is None, "Not supported"

@@ -309,10 +308,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        forward_mode: ForwardMode,
         encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
-        **kwargs,
     ):
         # NOTE: encoder_lens expected to be zeros or None
         if forward_mode.is_decode_or_idle():
@@ -582,9 +582,6 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None

-    # For decode
-    decode_seq_lens: List[int] = None
-
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None

@@ -1171,10 +1168,8 @@ class ScheduleBatch:
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode_or_idle():
-            decode_seq_lens = self.seq_lens.cpu()
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens

@@ -1199,7 +1194,6 @@ class ScheduleBatch:
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            decode_seq_lens=decode_seq_lens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,

@@ -1273,9 +1267,6 @@ class ModelWorkerBatch:
     global_num_tokens: Optional[List[int]]
     can_run_dp_cuda_graph: bool

-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
-
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
@@ -199,10 +199,6 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()

-        self.seq_lens_cpu = torch.full(
-            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
-        )
-
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)

@@ -373,9 +369,9 @@ class CudaGraphRunner:
            num_tokens,
            req_pool_indices,
            seq_lens,
+           encoder_lens,
            forward_batch.forward_mode,
-           encoder_lens=encoder_lens,
-           spec_info=forward_batch.spec_info,
+           forward_batch.spec_info,
        )

        # Run and capture

@@ -438,7 +434,6 @@ class CudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
-            self.seq_lens_cpu.fill_(1)

         # Common inputs
         self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)

@@ -446,8 +441,6 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)

         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

@@ -463,10 +456,9 @@ class CudaGraphRunner:
            self.req_pool_indices,
            self.seq_lens,
            forward_batch.seq_lens_sum + (bs - raw_bs),
+           self.encoder_lens,
            forward_batch.forward_mode,
-           encoder_lens=self.encoder_lens,
-           spec_info=forward_batch.spec_info,
-           seq_lens_cpu=self.seq_lens_cpu,
+           forward_batch.spec_info,
        )

        # Replay
@@ -152,9 +152,6 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None

-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
-
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None

@@ -259,8 +256,6 @@ class ForwardBatch:
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32