Unverified Commit edad3731 authored by Lianmin Zheng, committed by GitHub

Fix illegal memory access in overlap mode & Use more fused triton kernels for building meta data (#2051)
parent 976bc302
@@ -56,6 +56,7 @@ class BenchArgs:
     gen_output_len: int = 256
     disable_ignore_eos: bool = False
     seed: int = 1
+    do_not_exit: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -143,6 +144,11 @@ class BenchArgs:
             help="Disable ignore EOS token",
         )
         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--do-not-exit",
+            action="store_true",
+            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -309,3 +315,6 @@ if __name__ == "__main__":
     )

     throughput_test(server_args, bench_args)
+
+    while bench_args.do_not_exit:
+        pass
@@ -314,7 +314,6 @@ class FlashInferIndicesUpdaterDecode:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
         self.attn_backend = attn_backend

@@ -445,7 +444,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )

         wrapper.end_forward()
@@ -474,7 +473,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
-        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size
         self.attn_backend = attn_backend

@@ -599,7 +597,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_indptr,
             kv_start_idx,
             kv_indices,
-            self.max_context_len,
+            self.req_to_token.shape[1],
         )

         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +636,11 @@ def create_flashinfer_kv_indices_triton(
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
-    max_context_len: tl.constexpr,
+    req_to_token_ptr_stride: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
+
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -652,15 +651,15 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
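Note (reading aid): below is a rough PyTorch-level equivalent of what the rewritten kernel computes for each program id. The helper name and the treatment of kv_start_idx as optional are assumptions for illustration, not part of the diff.

import torch

def create_flashinfer_kv_indices_torch(req_to_token, req_pool_indices,
                                        page_kernel_lens, kv_indptr,
                                        kv_start_idx, kv_indices):
    # For request i, copy req_to_token[row, kv_start:kv_end] into the flat
    # kv_indices buffer starting at kv_indptr[i]. The Triton kernel does the
    # same copy in BLOCK_SIZE chunks, using req_to_token_ptr_stride
    # (= req_to_token.shape[1]) as the row stride.
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        kv_start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        kv_end = kv_start + int(page_kernel_lens[i])
        out = int(kv_indptr[i])
        kv_indices[out : out + (kv_end - kv_start)] = req_to_token[row, kv_start:kv_end]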
@@ -62,21 +62,21 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
+        extend_logprob_pruned_lens_cpu = None
+
         if forward_batch.return_logprob:
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-        else:
-            return_top_logprob = False
-
-        if forward_batch.forward_mode.is_extend():
-            extend_logprob_pruned_lens_cpu = [
-                extend_len - start_len
-                for extend_len, start_len in zip(
-                    forward_batch.extend_seq_lens,
-                    forward_batch.extend_logprob_start_lens_cpu,
-                )
-            ]
+            if forward_batch.forward_mode.is_extend():
+                extend_logprob_pruned_lens_cpu = [
+                    extend_len - start_len
+                    for extend_len, start_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.extend_logprob_start_lens_cpu,
+                    )
+                ]
         else:
-            extend_logprob_pruned_lens_cpu = None
+            return_top_logprob = False

         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -34,6 +34,8 @@ import logging
 from typing import List, Optional, Tuple, Union

 import torch
+import triton
+import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -615,12 +617,12 @@ class ScheduleBatch:
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
+        pre_lens = []

         # Allocate memory
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

-        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -638,10 +640,6 @@ class ScheduleBatch:
                 self.req_to_token_pool.write(
                     (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(pre_len, seq_len)),
-                    out_cache_loc[pt : pt + req.extend_input_len],
-                )

             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
@@ -652,8 +650,8 @@ class ScheduleBatch:
                 extend_logprob_start_len = req.extend_input_len - 1

             req.extend_logprob_start_len = extend_logprob_start_len
-            pt += req.extend_input_len
             req.is_retracted = False
+            pre_lens.append(pre_len)

         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -665,7 +663,6 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
@@ -676,9 +673,33 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        # Write to req_to_token_pool
+        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        write_req_to_token_pool_triton[(bs,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            pre_lens,
+            self.seq_lens,
+            extend_lens,
+            self.out_cache_loc,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        # The triton kernel is equivalent to the following python code.
+        # self.req_to_token_pool.write(
+        #     (req.req_pool_idx, slice(pre_len, seq_len)),
+        #     out_cache_loc[pt : pt + req.extend_input_len],
+        # )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)

+        # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
@@ -1025,6 +1046,9 @@ class ScheduleBatch:
         )

     def copy(self):
+        # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
+        _ = self.seq_lens[0].item()
+
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
@@ -1104,3 +1128,40 @@ class ModelWorkerBatch:
             for x, y in self.req_to_token_pool_records
         ]
         self.sampling_info.to(device)
+
+
+@triton.jit
+def write_req_to_token_pool_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices,
+    pre_lens,
+    seq_lens,
+    extend_lens,
+    out_cache_loc,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    req_pool_index = tl.load(req_pool_indices + pid)
+    pre_len = tl.load(pre_lens + pid)
+    seq_len = tl.load(seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_lens + i)
+
+    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < (seq_len - pre_len)
+        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + offset
+            + pre_len,
+            value,
+            mask=mask,
+        )
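As a sanity check on the fused write, the kernel above should produce the same result as the per-request writes it replaces. A minimal sketch of an equivalent plain-PyTorch loop (the helper name is made up for illustration):

import torch

def write_req_to_token_pool_torch(req_to_token, req_pool_indices, pre_lens,
                                  seq_lens, extend_lens, out_cache_loc):
    # Same logic as write_req_to_token_pool_triton: request i gets its newly
    # allocated cache slots written into row req_pool_indices[i] at positions
    # [pre_len, seq_len); the running sum over extend_lens selects the matching
    # slice of out_cache_loc, like cumsum_start in the kernel.
    pt = 0
    for i in range(req_pool_indices.numel()):
        row = int(req_pool_indices[i])
        pre_len, seq_len = int(pre_lens[i]), int(seq_lens[i])
        req_to_token[row, pre_len:seq_len] = out_cache_loc[pt : pt + (seq_len - pre_len)]
        pt += int(extend_lens[i])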
@@ -56,6 +56,7 @@ class TpModelWorkerClient:
         self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
+        self.gpu_id = gpu_id

         # Init future mappings
         self.future_token_ids_ct = 0
@@ -73,12 +74,6 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()

-        self.copy_queue = Queue()
-        self.copy_thread = threading.Thread(
-            target=self.copy_thread_func,
-        )
-        self.copy_thread.start()
-
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -104,12 +99,11 @@ class TpModelWorkerClient:
     @torch.inference_mode()
     def forward_thread_func_(self):
         while True:
-            self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
-            self.has_inflight_batch = True
             self.launch_event = threading.Event()
+            copy_event = torch.cuda.Event()

             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -142,19 +136,16 @@ class TpModelWorkerClient:
                 )
             )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()

             self.launch_event.set()
-            self.copy_queue.put((copy_event, logits_output, next_token_ids))
+            self.output_queue.put((copy_event, logits_output, next_token_ids))

-    def copy_thread_func(self):
-        while True:
-            copy_event, logits_output, next_token_ids = self.copy_queue.get()
-            if not copy_event:
-                break
-            while not copy_event.query():
-                time.sleep(1e-5)
+    def resulve_batch_result(self, bid: int):
+        copy_event, logits_output, next_token_ids = self.output_queue.get()
+        while not copy_event.query():
+            time.sleep(1e-5)
+        self.launch_event.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
@@ -167,14 +158,7 @@ class TpModelWorkerClient:
             logits_output.normalized_prompt_logprobs = (
                 logits_output.normalized_prompt_logprobs.tolist()
             )
-
-            self.output_queue.put((logits_output, next_token_ids.tolist()))
-
-    def resulve_batch_result(self, bid: int):
-        logits_output, next_token_ids = self.output_queue.get()
-        if self.has_inflight_batch:
-            # Wait until the batch is launched
-            self.launch_event.wait()
+        next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
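The worker change above removes the dedicated copy thread: the forward thread now records a CUDA event right after the asynchronous device-to-host copy, and the consumer polls that event before reading the CPU tensor. A condensed sketch of the handshake, with names invented for illustration:

import threading
import time
from queue import Queue

import torch

output_queue: Queue = Queue()
launch_event = threading.Event()

def producer(next_token_ids_gpu: torch.Tensor):
    copy_event = torch.cuda.Event()
    next_token_ids = next_token_ids_gpu.to("cpu", non_blocking=True)
    copy_event.record()          # recorded on the current stream, after the async copy
    launch_event.set()
    output_queue.put((copy_event, next_token_ids))

def consumer():
    copy_event, next_token_ids = output_queue.get()
    while not copy_event.query():  # True once the copy has finished on the GPU
        time.sleep(1e-5)
    launch_event.wait()
    return next_token_ids.tolist()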
@@ -36,6 +36,8 @@ from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional

 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
@@ -236,25 +238,16 @@ class ForwardBatch:
         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -271,3 +264,72 @@ class ForwardBatch:
             model_runner.lora_manager.prepare_lora_batch(ret)

         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    # TODO: optimize this?
+    cumsum_start = 0
+    for i in range(pid):
+        cumsum_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + cumsum_start + offset,
+            prefix_len + offset,
+            mask=offset < seq_len,
+        )
+    tl.store(extend_start_loc + pid, cumsum_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc
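A quick way to read the two helpers side by side is to run them on a tiny batch and compare outputs. A minimal sketch, assuming a CUDA device with Triton available and the two functions above in scope (the example tensors are made up):

import torch

# Two hypothetical requests: prefix lengths 3 and 0, extend lengths 2 and 4.
extend_prefix_lens = torch.tensor([3, 0], dtype=torch.int32, device="cuda")
extend_seq_lens = torch.tensor([2, 4], dtype=torch.int32, device="cuda")
extend_seq_lens_sum = int(extend_seq_lens.sum())

pos_triton, loc_triton = compute_position_triton(
    extend_prefix_lens, extend_seq_lens, extend_seq_lens_sum
)
pos_torch, loc_torch = compute_position_torch(extend_prefix_lens, extend_seq_lens)

# Both should yield positions [3, 4, 0, 1, 2, 3] and start offsets [0, 2].
assert torch.equal(pos_triton, pos_torch)
assert torch.equal(loc_triton, loc_torch)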
@@ -73,7 +73,7 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
-            is_all_greedy=top_ks.max().item() <= 1,
+            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             vocab_size=vocab_size,
             device=device,
         )