Unverified commit edad3731, authored by Lianmin Zheng and committed by GitHub


Fix illegal memory access in overlap mode & Use more fused triton kernels for building meta data (#2051)
parent 976bc302
@@ -56,6 +56,7 @@ class BenchArgs:
gen_output_len: int = 256
disable_ignore_eos: bool = False
seed: int = 1
do_not_exit: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -143,6 +144,11 @@ class BenchArgs:
help="Disable ignore EOS token",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--do-not-exit",
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -309,3 +315,6 @@ if __name__ == "__main__":
)
throughput_test(server_args, bench_args)
while bench_args.do_not_exit:
pass
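
The new --do-not-exit flag keeps the benchmark process alive after the run finishes so an external profiler (e.g. nsys profile with --duration and --delay) can finish collecting before the process exits. The commit uses a busy loop; a sleep-based variant would behave the same without pinning a CPU core. A minimal sketch, not part of this change:

import time

# Keep the process alive for an attached profiler without spinning a core.
# Sketch only; the commit itself uses `while bench_args.do_not_exit: pass`.
while bench_args.do_not_exit:
    time.sleep(0.1)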
@@ -314,7 +314,6 @@ class FlashInferIndicesUpdaterDecode:
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
@@ -445,7 +444,7 @@ class FlashInferIndicesUpdaterDecode:
kv_indptr,
kv_start_idx,
kv_indices,
self.max_context_len,
self.req_to_token.shape[1],
)
wrapper.end_forward()
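
Here (and in the prefill updater below) the cached self.max_context_len is replaced by self.req_to_token.shape[1], read straight off the table. For a contiguous req_to_token tensor of shape [max_num_reqs, max_context_len], that second dimension is exactly the per-row stride the Triton kernel needs to turn (req_pool_index, j) into a flat offset. A small sketch of the indexing identity, with hypothetical shapes:

import torch

# req_to_token[r, j] lives at flat offset r * req_to_token.shape[1] + j,
# assuming the table is contiguous (hypothetical 3x4 example).
req_to_token = torch.arange(12, dtype=torch.int32).reshape(3, 4)
r, j = 2, 1
assert req_to_token.view(-1)[r * req_to_token.shape[1] + j] == req_to_token[r, j]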
@@ -474,7 +473,6 @@ class FlashInferIndicesUpdaterPrefill:
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
@@ -599,7 +597,7 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr,
kv_start_idx,
kv_indices,
self.max_context_len,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +636,11 @@ def create_flashinfer_kv_indices_triton(
kv_indptr,
kv_start_idx,
kv_indices_ptr,
max_context_len: tl.constexpr,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
@@ -652,15 +651,15 @@ def create_flashinfer_kv_indices_triton(
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
req_to_token_ptr += req_pool_index * max_context_len
kv_indices_ptr += kv_indices_offset
ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
st_offset = tl.arange(0, BLOCK_SIZE)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = ld_offset < kv_end
data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
tl.store(kv_indices_ptr + st_offset, data, mask=mask)
ld_offset += BLOCK_SIZE
st_offset += BLOCK_SIZE
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
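
For reference, the rewritten loop performs a per-request copy of token locations into kv_indices. A rough pure-PyTorch equivalent (a sketch, treating kv_start_idx as optional, as at the call sites above):

def create_flashinfer_kv_indices_reference(
    req_to_token, req_pool_indices, page_kernel_lens, kv_indptr, kv_start_idx, kv_indices
):
    # Per request: copy the locations [kv_start, kv_end) of its row in
    # req_to_token into kv_indices, starting at kv_indptr[pid].
    for pid in range(req_pool_indices.numel()):
        req_pool_index = req_pool_indices[pid]
        out = kv_indptr[pid]
        kv_start = kv_start_idx[pid] if kv_start_idx is not None else 0
        kv_end = kv_start + page_kernel_lens[pid]
        kv_indices[out : out + (kv_end - kv_start)] = req_to_token[
            req_pool_index, kv_start:kv_end
        ]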
@@ -62,21 +62,21 @@ class LogitsMetadata:
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
extend_logprob_pruned_lens_cpu = None
if forward_batch.return_logprob:
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if forward_batch.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.extend_logprob_start_lens_cpu,
)
]
else:
return_top_logprob = False
if forward_batch.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
forward_batch.extend_seq_lens,
forward_batch.extend_logprob_start_lens_cpu,
)
]
else:
extend_logprob_pruned_lens_cpu = None
return cls(
forward_mode=forward_batch.forward_mode,
top_logprobs_nums=forward_batch.top_logprobs_nums,
......
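
The hunk above moves the computation of extend_logprob_pruned_lens_cpu under the return_logprob branch and reads forward_batch.extend_seq_lens_cpu (a plain Python list) instead of the tensor forward_batch.extend_seq_lens, so building this metadata no longer touches GPU data. A worked example with hypothetical lengths:

extend_seq_lens_cpu = [5, 8]              # tokens extended per request
extend_logprob_start_lens_cpu = [2, 0]    # logprobs start after this many of them
extend_logprob_pruned_lens_cpu = [
    extend_len - start_len
    for extend_len, start_len in zip(extend_seq_lens_cpu, extend_logprob_start_lens_cpu)
]
assert extend_logprob_pruned_lens_cpu == [3, 8]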
@@ -34,6 +34,8 @@ import logging
from typing import List, Optional, Tuple, Union
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
@@ -615,12 +617,12 @@ class ScheduleBatch:
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = []
pre_lens = []
# Allocate memory
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -638,10 +640,6 @@
self.req_to_token_pool.write(
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(pre_len, seq_len)),
out_cache_loc[pt : pt + req.extend_input_len],
)
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
@@ -652,8 +650,8 @@
extend_logprob_start_len = req.extend_input_len - 1
req.extend_logprob_start_len = extend_logprob_start_len
pt += req.extend_input_len
req.is_retracted = False
pre_lens.append(pre_len)
# Set fields
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -665,7 +663,6 @@
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
@@ -676,9 +673,33 @@
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
# Write to req_to_token_pool
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
write_req_to_token_pool_triton[(bs,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
pre_lens,
self.seq_lens,
extend_lens,
self.out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
)
# The triton kernel is equivalent to the following python code.
# self.req_to_token_pool.write(
# (req.req_pool_idx, slice(pre_len, seq_len)),
# out_cache_loc[pt : pt + req.extend_input_len],
# )
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
# Build sampling info
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
@@ -1025,6 +1046,9 @@
)
def copy(self):
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = self.seq_lens[0].item()
# Only contain fields that will be used by process_batch_result
return ScheduleBatch(
reqs=self.reqs,
@@ -1104,3 +1128,40 @@ class ModelWorkerBatch:
for x, y in self.req_to_token_pool_records
]
self.sampling_info.to(device)
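
On the stream-synchronization comment in copy() above: reading any element of a CUDA tensor with .item() forces the host to wait for the work queued on the current stream up to that point, which is what makes `_ = self.seq_lens[0].item()` act as a barrier. A standalone sketch of the effect, with hypothetical tensors:

import torch

x = torch.randn(8, device="cuda")
y = x * 2                  # queued asynchronously on the current stream
_ = y[0].item()            # host blocks here until the preceding work has run
# torch.cuda.current_stream().synchronize() would be the explicit way to wait
# for everything queued so far.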
@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)
# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)
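
A rough pure-PyTorch equivalent of the writes this kernel performs, one request per iteration (a sketch, assuming the 1-D integer tensors built at the call site above):

def write_req_to_token_pool_reference(
    req_to_token, req_pool_indices, pre_lens, seq_lens, extend_lens, out_cache_loc
):
    # For each request, place its newly allocated cache locations into
    # req_to_token[row, pre_len:seq_len], consuming out_cache_loc in order.
    cumsum_start = 0
    for pid in range(req_pool_indices.numel()):
        row = req_pool_indices[pid]
        pre_len, seq_len = pre_lens[pid], seq_lens[pid]
        req_to_token[row, pre_len:seq_len] = out_cache_loc[
            cumsum_start : cumsum_start + (seq_len - pre_len)
        ]
        cumsum_start += extend_lens[pid]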
@@ -56,6 +56,7 @@ class TpModelWorkerClient:
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device
self.gpu_id = gpu_id
# Init future mappings
self.future_token_ids_ct = 0
@@ -73,12 +74,6 @@
)
self.forward_thread.start()
self.copy_queue = Queue()
self.copy_thread = threading.Thread(
target=self.copy_thread_func,
)
self.copy_thread.start()
def get_worker_info(self):
return self.worker.get_worker_info()
@@ -104,12 +99,11 @@
@torch.inference_mode()
def forward_thread_func_(self):
while True:
self.has_inflight_batch = False
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break
self.has_inflight_batch = True
self.launch_event = threading.Event()
copy_event = torch.cuda.Event()
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
@@ -142,39 +136,29 @@
)
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event = torch.cuda.Event(blocking=True)
copy_event.record()
self.launch_event.set()
self.copy_queue.put((copy_event, logits_output, next_token_ids))
def copy_thread_func(self):
while True:
copy_event, logits_output, next_token_ids = self.copy_queue.get()
if not copy_event:
break
while not copy_event.query():
time.sleep(1e-5)
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
self.output_queue.put((logits_output, next_token_ids.tolist()))
self.output_queue.put((copy_event, logits_output, next_token_ids))
def resulve_batch_result(self, bid: int):
logits_output, next_token_ids = self.output_queue.get()
if self.has_inflight_batch:
# Wait until the batch is launched
self.launch_event.wait()
copy_event, logits_output, next_token_ids = self.output_queue.get()
while not copy_event.query():
time.sleep(1e-5)
self.launch_event.wait()
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
return logits_output, next_token_ids
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
......
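
With the separate copy thread removed, the forward thread records a CUDA event right after the non-blocking device-to-host copy of next_token_ids, and resulve_batch_result polls that event before converting the outputs to Python lists. A standalone sketch of the underlying pattern, using hypothetical tensors rather than sglang objects:

import time
import torch

x = torch.randint(0, 32000, (16,), device="cuda")   # e.g. sampled token ids
host = torch.empty(x.shape, dtype=x.dtype, pin_memory=True)
host.copy_(x, non_blocking=True)                     # async D2H copy (pinned target)
copy_event = torch.cuda.Event()
copy_event.record()                                  # marks the point after the copy

# ... the scheduler can keep launching the next batch here ...

while not copy_event.query():                        # True once the copy has completed
    time.sleep(1e-5)
token_ids = host.tolist()                            # now safe to read on the host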
@@ -36,6 +36,8 @@ from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional
import torch
import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
@@ -236,25 +238,16 @@ class ForwardBatch:
# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
ret.extend_num_tokens = batch.extend_num_tokens
ret.positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
)
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
@@ -271,3 +264,72 @@ class ForwardBatch:
model_runner.lora_manager.prepare_lora_batch(ret)
return ret
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
):
"""Compute positions. It is a fused version of `compute_position_torch`."""
batch_size = extend_seq_lens.shape[0]
positions = torch.empty(
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
)
extend_start_loc = torch.empty(
batch_size, dtype=torch.int32, device=extend_seq_lens.device
)
# Launch kernel
compute_position_kernel[(batch_size,)](
positions,
extend_start_loc,
extend_prefix_lens,
extend_seq_lens,
)
return positions, extend_start_loc
@triton.jit
def compute_position_kernel(
positions,
extend_start_loc,
extend_prefix_lens,
extend_seq_lens,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
prefix_len = tl.load(extend_prefix_lens + pid)
seq_len = tl.load(extend_seq_lens + pid)
# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_seq_lens + i)
num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
tl.store(
positions + cumsum_start + offset,
prefix_len + offset,
mask=offset < seq_len,
)
tl.store(extend_start_loc + pid, cumsum_start)
def compute_position_torch(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
):
positions = torch.concat(
[
torch.arange(
prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
)
for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
],
axis=0,
)
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
return positions.to(torch.int64), extend_start_loc
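
compute_position_torch is kept as the readable reference for the fused kernel, and the two should agree elementwise. A quick check on a small hypothetical batch (a CUDA device is required for the Triton path):

import torch

extend_prefix_lens = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")
extend_seq_lens = torch.tensor([4, 2, 5], dtype=torch.int32, device="cuda")

pos_ref, loc_ref = compute_position_torch(extend_prefix_lens, extend_seq_lens)
pos_fused, loc_fused = compute_position_triton(
    extend_prefix_lens, extend_seq_lens, int(extend_seq_lens.sum())
)
assert torch.equal(pos_ref, pos_fused)   # [0,1,2,3, 3,4, 7,8,9,10,11]
assert torch.equal(loc_ref, loc_fused)   # [0, 4, 6]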
@@ -73,7 +73,7 @@ class SamplingBatchInfo:
top_ks=top_ks,
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=top_ks.max().item() <= 1,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
vocab_size=vocab_size,
device=device,
)
......
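
The is_all_greedy change replaces top_ks.max().item(), which synchronizes with the device on the freshly built top_ks tensor, with a pure-Python check over the request parameters, so building the sampling metadata stays asynchronous. Sketch of the two forms (hypothetical reqs and device):

# New form: no device synchronization.
is_all_greedy = all(r.sampling_params.top_k <= 1 for r in reqs)

# Previous form: builds a CUDA tensor, then .item() forces a GPU -> CPU sync.
# top_ks = torch.tensor([r.sampling_params.top_k for r in reqs], device=device)
# is_all_greedy = top_ks.max().item() <= 1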