Unverified Commit a837166e authored by Lianmin Zheng, committed by GitHub

Fix select and normalized logprobs (#67)

parent 11f3cca6
sglang/srt/layers/get_selected_logprob.py (removed in this commit):

import torch
import triton
import triton.language as tl

from sglang.srt.utils import wrap_kernel_launcher


@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request. Request `cur_req` owns a contiguous segment of
    # `len_add_1[cur_req]` rows in `all_logits`; `cum_len` is the inclusive
    # cumsum of these lengths, so the segment starts at row `cum_l - cur_l`.
    cur_req = tl.program_id(0)
    cur_l = tl.load(len_add_1 + cur_req)
    cum_l = tl.load(cum_len + cur_req)

    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = off < cur_l - 1
        # Position `off` is scored against the observed next token.
        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
        # Outputs are packed: each request drops one slot, hence `- cur_req`.
        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)


cached_kernel = None


def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (len_add_1.shape[0], 1, 1)
    max_seq_len = len_add_1.max().item()

    global cached_kernel
    if cached_kernel:
        # Reuse the cached launcher to skip Triton's dispatch overhead.
        cached_kernel(
            grid,
            4,
            all_logits,
            len_add_1,
            cum_len,
            input_ids,
            logprobs,
            max_seq_len,
        )
        return

    _fwd_segmented_gather[grid](
        all_logits,
        len_add_1,
        cum_len,
        input_ids,
        logprobs,
        max_seq_len,
        voc_size,
        BLOCK_SIZE=128,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
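

# Editor's sketch (not in the original commit): a plain-PyTorch reference for
# the segmented gather above, handy for unit-testing the kernel. The helper
# name is invented here.
def _selected_logprob_torch_ref(all_logits, len_add_1, input_ids):
    out = []
    start = 0
    for l in len_add_1.tolist():
        # Rows start .. start + l - 2 are scored against the next tokens
        # input_ids[start + 1 .. start + l - 1].
        rows = torch.arange(start, start + l - 1, device=all_logits.device)
        idx = input_ids[start + 1 : start + l].long()
        out.append(all_logits[rows, idx])
        start += l
    # With the tensors from the test below this returns [2, 2, 4].
    return torch.cat(out)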


if __name__ == "__main__":
    all_logits = torch.tensor(
        #  s                          s             s
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.empty((3,), dtype=torch.float32, device="cuda")

    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # expected: logprobs == [2, 2, 4]
sglang/srt/layers/logits_processor.py:

 import torch
-from sglang.srt.layers.get_selected_logprob import get_selected_logprob
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
 from vllm.model_executor.parallel_utils.communication_op import (

@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
             normalized_logprobs = compute_normalized_logprobs(
                 all_logprobs,
-                input_metadata.seq_lens - input_metadata.prefix_lens,
                 input_ids,
+                input_metadata.extend_seq_lens,
+                input_metadata.extend_start_loc,
             )
             last_logits = logits[last_index]
             return last_logits, normalized_logprobs
-def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
-    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
-    logprobs = torch.zeros(
-        (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
-    )
-    get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
-    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-    end = torch.cumsum(len_add_1.sub_(1), dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
-    end.sub_(1)
-    torch.cuda.synchronize()
-    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
-    res = sum_logp / len_add_1
-    return res
+def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
+    # Token-level log probs: row i is scored against the observed next token
+    # input_ids[i + 1]; the last row gets a dummy index 0 and is never summed.
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    # Per-request segment sum via cumsum differences:
+    # sum(logprobs[start : end + 1]) == cumsum[end] - cumsum[start] + logprobs[start],
+    # normalized by the number of scored tokens (seq_len - 1).
+    start = start_loc.clone()
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    return sum_logp / ((seq_lens - 1).clamp(min=1))
+
+
+if __name__ == "__main__":
+    all_logprobs = torch.tensor(
+        #  s                          s             s
+        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
+        dtype=torch.float32,
+        device="cuda",
+    )
+    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
+    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
+    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
+
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    len_cumsum = torch.cumsum(seq_lens, dim=0)
+    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+
+    # expected: logprobs == [2, _, 2, 4, _]
+    print("logprobs", logprobs)
+    print("start", start)
+    print("end", end)
+    print("sum_logp", sum_logp)
sglang/srt/managers/router/model_runner.py:

+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
-import logging
 
 import numpy as np
 import torch
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
 logger = logging.getLogger("model_runner")
@@ -112,7 +111,7 @@ class InputMetadata:
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.max_extend_len = int(torch.max(self.extend_seq_lens))
 
     @classmethod
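# Editor's sketch (invented values): the exclusive cumsum above turns packed
# extend lengths into per-request start offsets, e.g.
#   extend_seq_lens  = [3, 1, 4]  ->  extend_start_loc = [0, 3, 4]
# which is exactly the start_loc consumed by compute_normalized_logprobs.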
@@ -262,7 +261,7 @@ class ModelRunner:
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
 
-        logger.info("load weight begin.")
+        logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
         linear_method = None
@@ -287,7 +286,7 @@
         )
         self.model = model.eval()
-        logger.info("load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end.")
 
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
@@ -308,8 +307,9 @@
         self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
 
         if self.max_total_num_token <= 0:
-            raise RuntimeError("Not enough memory. "
-                               "Please try to increase --mem-fraction-static.")
+            raise RuntimeError(
+                "Not enough memory. " "Please try to increase --mem-fraction-static."
+            )
         self.req_to_token_pool = ReqToTokenPool(
             int(self.max_total_num_token / self.model_config.context_len * 256),
...
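# Editor's sketch (invented numbers): the ReqToTokenPool sizing above allots
# one request slot per (context_len / 256) tokens of KV budget, e.g.
# max_total_num_token = 100000 with context_len = 2048 gives
# int(100000 / 2048 * 256) == 12500 request slots.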