Unverified Commit 3d8f1c9b authored by Lianmin Zheng, committed by GitHub

Use int64 as indices for set_kv_buffer (#3039)

parent a42213db
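The motivation, as the title suggests, is index-width safety: once a KV cache pool grows past a couple of billion elements, flattened element offsets no longer fit in int32. The sketch below is illustrative only (the sizes are made up and set_kv_buffer itself is not shown); it just demonstrates where int32 arithmetic would overflow and why computing offsets in int64 is safe.

import torch

# Illustrative numbers only, not from this PR: a KV pool with 600k token
# slots, 32 KV heads, and head_dim 128 holds more than 2**31 fp16 elements
# per layer, so a flattened int32 element offset would overflow.
num_slots, num_kv_heads, head_dim = 600_000, 32, 128
elements_per_layer = num_slots * num_kv_heads * head_dim
print(elements_per_layer > torch.iinfo(torch.int32).max)  # True

# Slot indices themselves may still fit in int32, but arithmetic that flattens
# them into element offsets is only safe in int64.
slots = torch.arange(num_slots, dtype=torch.int64)
offsets = slots * (num_kv_heads * head_dim)  # exact in int64, would wrap in int32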
@@ -99,10 +99,7 @@ class BenchArgs:
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
         parser.add_argument(
-            "--profile",
-            action="store_true",
-            help="Use Torch Profiler. The endpoint must be launched with "
-            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+            "--profile", action="store_true", help="Use Torch Profiler."
         )
         parser.add_argument(
             "--profile-filename-prefix",
@@ -381,6 +378,7 @@ def latency_test_run_once(
         parent_dir = os.path.dirname(os.path.abspath(profile_filename))
         os.makedirs(parent_dir, exist_ok=True)
         profiler.export_chrome_trace(profile_filename)
+        rank_print(f"torch profiler chrome trace saved to {profile_filename}")

     # Record decode timing from 2nd output
     if output_len > 1:
@@ -451,7 +449,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
-            bench_args.profile,
+            bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
         if ret is not None:
...
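For reference, a minimal sketch of the profiling pattern the benchmark relies on (plain torch.profiler usage, not code from this commit); with the change above, only tp_rank 0 receives a truthy profile flag, so only one rank writes a trace.

import torch
from torch.profiler import ProfilerActivity, profile

# Minimal sketch, not from bench_one_batch: profile a region and export a
# Chrome trace, roughly what --profile triggers for the profiling rank.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    torch.matmul(torch.randn(256, 256), torch.randn(256, 256))

prof.export_chrome_trace("trace.json")  # viewable in chrome://tracing or Perfetto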
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
...
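The kernel change above casts the program id to int64 before it is multiplied by BLOCK_SIZE, so offsets cannot wrap for tensors larger than 2**31 elements. A self-contained Triton sketch of the same pattern (a plain copy kernel for illustration, not the softcap kernel itself):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Cast the program id to int64 before multiplying, so block_start does not
    # wrap around once n_elements exceeds 2**31 - 1.
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)

def copy(src: torch.Tensor, dst: torch.Tensor, block_size: int = 1024):
    # Launch helper; requires a CUDA device.
    n = src.numel()
    copy_kernel[(triton.cdiv(n, block_size),)](src, dst, n, BLOCK_SIZE=block_size)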
 import logging
-from typing import Dict, List
+from typing import List

 import torch
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
@@ -109,8 +108,6 @@ class Sampler(nn.Module):
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )

-        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +121,7 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]

-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)

     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
...
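The sampler change keeps batch_next_token_ids at its native (typically int64) dtype while logprobs are gathered and casts to int32 only in the final return. A sketch of why the wider dtype is convenient mid-way through (illustrative shapes, not the Sampler code): torch.gather expects int64 indices.

import torch

# Illustrative shapes only, not the Sampler implementation.
logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)
next_token_ids = torch.randint(0, 32000, (4,), dtype=torch.int64)

# gather requires int64 indices, so keeping the ids wide avoids a round-trip cast.
token_logprobs = logprobs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)

# Cast to int32 only at the boundary where the narrower dtype is expected.
next_token_ids_int32 = next_token_ids.to(torch.int32)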
@@ -550,13 +550,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32

     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -1026,7 +1026,7 @@ class ScheduleBatch:
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
...
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -63,7 +63,7 @@ def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
     batch_size: int,
-    tp_group: "GroupCoordinator",
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )

         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )

-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
@@ -189,14 +187,14 @@ class CudaGraphRunner:
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

         # Speculative_inference
         if model_runner.spec_algorithm.is_eagle():
...
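An aside on the dtype changes to the static buffers above: CUDA graph replay reuses the exact kernels recorded at capture time, so the preallocated inputs must keep the same dtype and shape on every replay, and new values are copied into them before replay. A minimal sketch of that pattern (generic torch.cuda.CUDAGraph usage, not CudaGraphRunner itself):

import torch

# Requires a CUDA device. Static buffers are allocated once with fixed dtypes.
static_idx = torch.zeros(8, dtype=torch.int64, device="cuda")
static_out = torch.zeros(8, dtype=torch.float16, device="cuda")
table = torch.randn(1024, dtype=torch.float16, device="cuda")

# (Production code usually runs a few warmup iterations on a side stream first.)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # The gather captured here is baked into the graph with these dtypes.
    static_out.copy_(table[static_idx])

# Replay with fresh inputs: write into the captured buffer, then replay.
static_idx.copy_(torch.randint(0, 1024, (8,), device="cuda"))
g.replay()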
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -415,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
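As the name suggests, maybe_torch_compile presumably applied torch.compile conditionally; the new decorator compiles unconditionally and lets get_compiler_backend() choose the backend. A rough sketch of the idea with a hypothetical backend picker (the real helper in sglang.srt.utils may choose differently):

import torch

def pick_backend() -> str:
    # Hypothetical stand-in for sglang.srt.utils.get_compiler_backend.
    return "inductor" if torch.cuda.is_available() else "eager"

@torch.compile(dynamic=True, backend=pick_backend())
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Last valid position per sequence; empty sequences clamp to 0.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

print(clamp_position(torch.tensor([0, 3, 7])))  # tensor([0, 2, 6])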