Unverified Commit a837166e authored by Lianmin Zheng, committed by GitHub

Fix select and normalized logprobs (#67)

parent 11f3cca6
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher


@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    cur_req = tl.program_id(0)
    cur_l = tl.load(len_add_1 + cur_req)
    cum_l = tl.load(cum_len + cur_req)

    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = off < cur_l - 1
        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
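

# A plain-PyTorch sketch of what the kernel above computes, useful as a mental
# model or as a reference in tests. The helper name `_segmented_gather_reference`
# is illustrative and not part of this commit. For each request r with
# cur_l = len_add_1[r] tokens, it gathers all_logits[t, input_ids[t + 1]] for the
# first cur_l - 1 positions of that request and packs the results contiguously,
# matching the kernel's output layout.
def _segmented_gather_reference(all_logits, len_add_1, input_ids):
    out = []
    offset = 0
    for cur_l in len_add_1.tolist():
        for off in range(cur_l - 1):
            pos = offset + off
            idx = int(input_ids[pos + 1].item())
            out.append(all_logits[pos, idx])
        offset += cur_l
    return torch.stack(out) if out else all_logits.new_empty(0)

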
cached_kernel = None


def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (len_add_1.shape[0], 1, 1)
    max_seq_len = len_add_1.max().item()

    global cached_kernel
    if cached_kernel:
        # Reuse the wrapped launcher to skip Triton's dispatch overhead.
        cached_kernel(
            grid,
            4,  # num_warps
            all_logits,
            len_add_1,
            cum_len,
            input_ids,
            logprobs,
            max_seq_len,
        )
        return

    _fwd_segmented_gather[grid](
        all_logits,
        len_add_1,
        cum_len,
        input_ids,
        logprobs,
        max_seq_len,
        voc_size,
        BLOCK_SIZE=128,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)


if __name__ == "__main__":
    all_logits = torch.tensor(
        # the selected entries are 2 (row 0), 2 (row 2), and 4 (row 3); see the expected result below
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.empty((3,), dtype=torch.float32, device="cuda")

    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # expected: logprobs == [2, 2, 4]
import torch
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn
from vllm.model_executor.parallel_utils.communication_op import (
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
        normalized_logprobs = compute_normalized_logprobs(
            all_logprobs,
            input_metadata.seq_lens - input_metadata.prefix_lens,
            input_ids,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
        )
        last_logits = logits[last_index]
        return last_logits, normalized_logprobs


def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
    logprobs = torch.zeros(
        (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"


def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
    logprobs = all_logprobs[
        torch.arange(all_logprobs.shape[0], device="cuda"),
        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
    ]
    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

    start = start_loc.clone()
    end = start + seq_lens - 2
    start.clamp_(min=0, max=logprobs.shape[0] - 1)
    end.clamp_(min=0, max=logprobs.shape[0] - 1)
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
    return sum_logp / ((seq_lens - 1).clamp(min=1))
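

# Worked example of the inclusive-cumsum window trick used above (the numbers are
# illustrative, not from the commit): with logprobs = [a, b, c, d, e] and
# logprobs_cumsum = [a, a+b, a+b+c, a+b+c+d, a+b+c+d+e], the sum over the inclusive
# window [start, end] is logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start];
# e.g. start = 1, end = 3 gives (a+b+c+d) - (a+b) + b = b + c + d. Each request
# contributes seq_len - 1 conditional logprobs (its first token is not scored by its
# own segment), hence end = start + seq_len - 2 and the division by (seq_len - 1),
# clamped to 1 so requests that extend nothing do not divide by zero.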


if __name__ == "__main__":
    all_logprobs = torch.tensor(
        # the entries picked out by this test are 2 (row 0), 2 (row 2), and 4 (row 3)
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
    end = torch.cumsum(len_add_1.sub_(1), dim=0)
    start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
    end.sub_(1)
    torch.cuda.synchronize()
    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
    res = sum_logp / len_add_1
    return res
    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")

    logprobs = all_logprobs[
        torch.arange(all_logprobs.shape[0], device="cuda"),
        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
    ]
    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
    len_cumsum = torch.cumsum(seq_lens, dim=0)
    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
    end = start + seq_lens - 2
    start.clamp_(min=0, max=logprobs.shape[0] - 1)
    end.clamp_(min=0, max=logprobs.shape[0] - 1)
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]

    # assert logprobs == [2, _, 2, 4, _]
    print("logprobs", logprobs)
    print("start", start)
    print("end", end)
    print("sum_logp", sum_logp)
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
import logging
import numpy as np
import torch
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
logger = logging.getLogger("model_runner")
@@ -112,7 +111,7 @@ class InputMetadata:
    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))
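        # Worked example (illustrative, not from the commit): with
        # extend_seq_lens = [2, 0, 3], the exclusive prefix sum above gives
        # extend_start_loc = [0, 2, 2], i.e. the offset at which each request's
        # newly extended tokens start in the flattened batch, and
        # max_extend_len = 3.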

    @classmethod
@@ -262,7 +261,7 @@ class ModelRunner:
        if model_class is None:
            raise ValueError(f"Unsupported architectures: {architectures}")
        logger.info("load weight begin.")
        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
        linear_method = None
@@ -287,7 +286,7 @@
        )
        self.model = model.eval()
        logger.info("load weight end.")
        logger.info(f"Rank {self.tp_rank}: load weight end.")
    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(

@@ -308,8 +307,9 @@
        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
        if self.max_total_num_token <= 0:
            raise RuntimeError("Not enough memory. "
                               "Please try to increase --mem-fraction-static.")
            raise RuntimeError(
                "Not enough memory. " "Please try to increase --mem-fraction-static."
            )
        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_token / self.model_config.context_len * 256),
......