Commit 22085081 authored by Lianmin Zheng
parent f6d40df0
import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@triton.jit
def _fwd_kernel(
Q_Extend,
K_Extend,
V_Extend,
O_Extend,
K_Buffer,
V_Buffer,
Req_to_tokens,
B_req_idx,
B_Seq_Len,
B_Start_Loc_Extend,
B_Seq_Len_Extend,
sm_scale,
kv_group_num,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_req_to_tokens_b,
BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
cur_block_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
cur_seq_prefix_start_in_loc = 0
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
offs_q = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
# stage1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
cur_seq_prefix_start_in_loc + start_n + offs_n
)
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
# load k in transposed way
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_d[None, :]
)
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
# stage2: compute the triangle part
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
# load k in transposed way
offs_k = (
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causal &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causal, qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_v = (
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :]
)
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
offs_o = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
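# Note on the two loops above (a restatement of the existing code, not new behavior):
# both use the streaming-softmax recurrence. With running max m, denominator d, and
# accumulator a, a new block of logits s with value block V updates the state as
#   m' = max(m, max(s))
#   d' = d * exp(m - m') + sum(exp(s - m'))
#   a' = a * exp(m - m') + exp(s - m') @ V
# and the final output row is a' / d'. Stage 1 runs this over the cached prefix
# tokens, stage 2 over the causally masked extend tokens, so the result equals full
# attention over [prefix; extend] restricted to the extend queries.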
def extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
"""
BLOCK_M, BLOCK_N = 128, 128
Lq, Lk, Lv, Lo = (
q_extend.shape[-1],
k_extend.shape[-1],
v_extend.shape[-1],
o_extend.shape[-1],
)
assert Lq == Lk and Lk == Lv and Lv == Lo
assert Lq in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_warps = 4 if Lk <= 64 else 8
num_stages = 1
_fwd_kernel[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
BLOCK_DMODEL=Lq,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)
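# Illustrative sketch only (not used by the kernel above): for one request whose
# prefix and extend tokens are concatenated into dense per-head tensors, the result
# computed by _fwd_kernel / extend_attention_fwd is equivalent to plain causal
# attention restricted to the extend queries. All names below are hypothetical and
# assume the query and KV head counts match (kv_group_num == 1).
def _naive_extend_attention_reference(q_ext, k_all, v_all, prefix_len, sm_scale):
    # q_ext: (L_ext, H, D); k_all, v_all: (L_prefix + L_ext, H, D)
    out = torch.empty_like(q_ext)
    pos_kv = torch.arange(k_all.shape[0], device=q_ext.device)
    pos_q = prefix_len + torch.arange(q_ext.shape[0], device=q_ext.device)
    for h in range(q_ext.shape[1]):
        scores = q_ext[:, h].float() @ k_all[:, h].float().T * sm_scale
        # an extend query may attend to the prefix and to extend positions <= its own
        scores = scores.masked_fill(pos_kv[None, :] > pos_q[:, None], float("-inf"))
        out[:, h] = (torch.softmax(scores, dim=-1) @ v_all[:, h].float()).to(q_ext.dtype)
    return out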
def redundant_attention(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
):
total_token_num = k_buffer.shape[0]
B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]
q_buffer = torch.empty(
(total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device
)
pt = 0
for i in range(B):
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend]
pt += cur_seq_len_extend
o_buffer = torch.empty_like(q_buffer)
context_attention_fwd(
q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch
)
pt = 0
for i in range(B):
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
pt += cur_seq_len_extend
def test():
torch.manual_seed(0)
B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
dtype = torch.float16
b_seq_len_prefix = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
for i in range(B):
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
b_seq_len_extend = b_seq_len - b_seq_len_prefix
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
)
redundant_attention(
q_extend,
k_extend,
v_extend,
o_redundant,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
)
print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))
assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
if __name__ == "__main__":
test()
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_segmented_gather(
all_logits,
len_add_1,
cum_len,
input_ids,
logprobs,
max_seq_len,
voc_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
cur_req = tl.program_id(0)
cur_l = tl.load(len_add_1 + cur_req)
cum_l = tl.load(cum_len + cur_req)
for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = off < cur_l - 1
idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
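# Reading of the kernel above: for request r with cur_l = len_add_1[r] rows starting
# at offset cum_len[r] - cur_l in all_logits, it gathers all_logits[t, input_ids[t + 1]]
# for the first cur_l - 1 positions t of that request and stores them contiguously in
# logprobs; the "- cur_req" in the store offset accounts for the one entry dropped per
# preceding request (each request contributes len_add_1 - 1 selected values).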
cached_kernel = None
def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
voc_size = all_logits.shape[1]
grid = (len_add_1.shape[0], 1, 1)
max_seq_len = len_add_1.max().item()
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
4,
all_logits,
len_add_1,
cum_len,
input_ids,
logprobs,
max_seq_len,
)
return
_fwd_segmented_gather[grid](
all_logits,
len_add_1,
cum_len,
input_ids,
logprobs,
max_seq_len,
voc_size,
BLOCK_SIZE=128,
)
cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
if __name__ == "__main__":
all_logits = torch.tensor(
# s s s
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
dtype=torch.float32,
device="cuda",
)
len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
logprobs = torch.empty((3), dtype=torch.float32, device="cuda")
get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
print(logprobs)
# assert logprobs == [2, 2, 4]
import torch
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn
from vllm.model_executor.parallel_utils.communication_op import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
class LogitsProcessor(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
def forward(self, input_ids, hidden_states, weight, input_metadata):
if not input_metadata.return_normalized_logprob:
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
last_hidden = hidden_states[last_index]
hidden_states = None
last_logits = torch.matmul(last_hidden, weight.T)
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, None
else:
assert input_metadata.forward_mode != ForwardMode.DECODE
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
logits = tensor_model_parallel_all_gather(logits)
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
normalized_logprobs = compute_normalized_logprobs(
all_logprobs,
input_metadata.seq_lens - input_metadata.prefix_lens,
input_ids,
)
last_logits = logits[last_index]
return last_logits, normalized_logprobs
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
logprobs = torch.zeros(
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
)
get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
end = torch.cumsum(len_add_1.sub_(1), dim=0)
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
end.sub_(1)
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
res = sum_logp / len_add_1
return res
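# Hypothetical, CPU-only illustration of the segment-sum trick above: with the
# selected per-token logprobs packed back to back and per-request lengths known,
# cumsum[end] - cumsum[start] + logprobs[start] yields each request's segment sum.
def _segment_mean_logprob_example():
    logprobs = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])  # 2 + 3 selected logprobs
    lens = torch.tensor([2, 3])  # number of selected logprobs per request
    cumsum = torch.cumsum(logprobs, dim=0)
    end = torch.cumsum(lens, dim=0) - 1  # last index of each segment
    start = torch.cat((torch.tensor([0]), end[:-1] + 1))  # first index of each segment
    seg_sum = cumsum[end] - cumsum[start] + logprobs[start]  # [0.3, 1.2]
    return seg_sum / lens  # mean logprob per request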
from typing import List
import torch
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
class RadixAttention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scaling,
num_kv_heads,
layer_id,
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.layer_id = layer_id
from sglang.srt.managers.router.model_runner import global_model_mode
self.use_flashinfer = "flashinfer" in global_model_mode
if self.use_flashinfer:
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
else:
self.prefill_forward = self.prefill_forward_triton
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
context_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k,
v,
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
)
self.store_kv_cache(k, v, input_metadata)
return o
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.prefix_lens,
input_metadata.extend_start_loc,
input_metadata.extend_seq_lens,
input_metadata.max_seq_len,
input_metadata.max_extend_len,
)
return o
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
token_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
input_metadata.other_kv_index,
input_metadata.total_num_tokens,
)
return o
def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
o = input_metadata.prefill_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.qo_indptr,
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
allow_fp16_qk_reduction=True,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
o = input_metadata.decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
def forward(self, q, k, v, input_metadata: InputMetadata):
k = k.view(-1, self.tp_k_head_num, self.head_dim)
v = v.view(-1, self.tp_v_head_num, self.head_dim)
if input_metadata.forward_mode == ForwardMode.PREFILL:
return self.prefill_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.EXTEND:
return self.extend_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
if input_metadata.out_cache_loc is not None:
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer[input_metadata.out_cache_loc] = cache_v
elif input_metadata.out_cache_cont_start is not None:
key_buffer[
input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
] = cache_k
value_buffer[
input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
] = cache_v
else:
raise RuntimeError()
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_kernel_stage1(
Q,
K_Buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
att_stride_h,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_start_index = 0
cur_batch_end_index = cur_batch_seq_len
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_start_index = start_n * BLOCK_N
block_mask = tl.where(block_start_index < cur_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
offs_buf_k = (
k_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[None, :]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@triton.jit
def _fwd_kernel_stage2(
Logics,
V_Buffer,
Out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
stride_logic_h,
stride_buf_vbs,
stride_buf_vh,
stride_obs,
stride_oh,
stride_req_to_token_b,
other_kv_index, # To fix a NaN issue
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
v_ptrs = V_Buffer + offs_buf_v
e_max = float("-inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_index = tl.load(
Req_to_tokens
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n),
mask=(start_n + offs_n) < cur_batch_seq_len,
other=other_kv_index,
)
qk = tl.load(
Logics
+ cur_head * stride_logic_h
+ (cur_batch_start_loc + start_n + offs_n),
mask=start_n + offs_n < cur_batch_seq_len,
other=float("-inf"),
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
e_max = n_e_max
acc = acc / e_sum
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
cached_kernel_stage1 = None
cached_kernel_stage2 = None
def _token_att_m_fwd(
q,
k_buffer,
att_out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
max_len_in_batch,
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lk**0.5)
batch, head_num = B_req_idx.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
kv_group_num = q.shape[1] // k_buffer.shape[1]
if kv_group_num == 1:
num_warps = 4
else:
num_warps = 2
global cached_kernel_stage1
if cached_kernel_stage1:
cached_kernel_stage1(
grid,
num_warps,
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
)
return
_fwd_kernel_stage1[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
def _token_softmax_reducev_fwd(
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
grid = (batch, head, 1)
kv_group_num = logics.shape[0] // v_buffer.shape[1]
num_warps = 1
global cached_kernel_stage2
if cached_kernel_stage2:
cached_kernel_stage2(
grid,
num_warps,
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
)
return
_fwd_kernel_stage2[grid](
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
kv_group_num=kv_group_num,
BLOCK_DMODEL=v_buffer.shape[-1],
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=3,
)
cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
def token_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
other_kv_index,
total_num_tokens,
att_m=None,
):
if att_m is None:
att_m = torch.empty(
(q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
)
_token_att_m_fwd(
q,
k_buffer,
att_m,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
)
_token_softmax_reducev_fwd(
att_m,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
)
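# Minimal sketch (hypothetical names, single request, single head) of what the two
# stages above compute together: stage 1 fills att_m with scaled q.k scores over the
# request's cached tokens, stage 2 applies a numerically stable softmax and reduces
# against the cached values.
def _naive_token_attention_reference(q, k, v, sm_scale):
    # q: (D,); k, v: (seq_len, D) for one head of one decoding request
    scores = (k.float() @ q.float()) * sm_scale  # stage 1: per-token logits
    probs = torch.softmax(scores, dim=0)  # stage 2: softmax over the sequence
    return (probs @ v.float()).to(v.dtype)  # weighted sum of cached values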
import asyncio
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DetokenizerManager:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
context = zmq.asyncio.Context(2)
self.recv_from_router = context.socket(zmq.PULL)
self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
async def handle_loop(self):
while True:
recv_obj = await self.recv_from_router.recv_pyobj()
if isinstance(recv_obj, BatchTokenIDOut):
output_tokens = recv_obj.output_tokens
# TODO(lmzheng): handle skip_special_tokens per request
output_strs = self.tokenizer.batch_decode(
output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0],
)
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for i in range(len(output_strs)):
if recv_obj.hit_stop_str[i] is not None:
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
if pos != -1:
output_strs[i] = output_strs[i][:pos]
if len(output_tokens[i]) > 0:
first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0])
)
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
recv_obj.rids,
output_strs,
recv_obj.meta_info,
recv_obj.finished,
)
)
else:
raise ValueError(f"Invalid object: {recv_obj}")
def start_detokenizer_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
try:
manager = DetokenizerManager(server_args, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.get_event_loop()
loop.run_until_complete(manager.handle_loop())
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams
@dataclass
class GenerateReqInput:
text: Union[List[str], str]
image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False
def post_init(self):
is_single = isinstance(self.text, str)
if is_single:
if self.sampling_params is None:
self.sampling_params = {}
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.return_normalized_logprob is None:
self.return_normalized_logprob = False
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = 0
else:
num = len(self.text)
if self.image_data is None:
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num
if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
self.sampling_params = [self.sampling_params] * num
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)]
else:
assert isinstance(self.rid, list)
if self.return_normalized_logprob is None:
self.return_normalized_logprob = [False] * num
elif not isinstance(self.return_normalized_logprob, list):
self.return_normalized_logprob = [self.return_normalized_logprob] * num
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = [0] * num
elif not isinstance(self.normalized_logprob_start_len, list):
self.normalized_logprob_start_len = [
self.normalized_logprob_start_len
] * num
@dataclass
class TokenizedGenerateReqInput:
rid: str
input_ids: List[int]
pixel_values: List[float]
image_hash: int
sampling_params: SamplingParams
return_normalized_logprob: bool
normalized_logprob_start_len: int
stream: bool
@dataclass
class BatchTokenIDOut:
rids: List[str]
output_tokens: List[List[int]]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool]
meta_info: List[Dict]
finished: List[bool]
@dataclass
class BatchStrOut:
rids: List[str]
output_str: List[str]
meta_info: List[Dict]
finished: List[bool]
from dataclasses import dataclass
from typing import Any, List, Optional, Union
@dataclass
class CompletionRequest:
prompt: Union[str, List[Any]]
model: str = "default"
temperature: Optional[float] = 0.7
max_tokens: Optional[int] = 16
n: Optional[int] = 1
stop: Optional[Union[str, List[str]]] = None
from enum import Enum, auto
from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
class ForwardMode(Enum):
PREFILL = auto()
EXTEND = auto()
DECODE = auto()
class FinishReason(Enum):
LENGTH = auto()
EOS_TOKEN = auto()
STOP_STR = auto()
class Req:
def __init__(self, rid):
self.rid = rid
self.input_ids = []
self.output_ids = []
self.pixel_values = None
self.image_offset = 0
self.sampling_params = None
self.return_normalized_logprob = False
self.normalized_logprob_start_len = 0
self.stream = False
self.tokenizer = None
self.finished = False
self.finish_reason = None
self.hit_stop_str = None
self.adjust_input_len = 0
self.prefix_indices = []
self.normalized_logprob = None
# for constrained decoding
self.regex_fsm = None
self.regex_fsm_state = None
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
def check_finished(self):
if self.finished:
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished = True
self.finish_reason = FinishReason.LENGTH
return
if (
self.output_ids[-1] == self.tokenizer.eos_token_id
and not self.sampling_params.ignore_eos
):
self.finished = True
self.finish_reason = FinishReason.EOS_TOKEN
return
if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode(
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str:
self.finished = True
self.finish_reason = FinishReason.STOP_STR
self.hit_stop_str = stop_str
return
def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
class Batch:
def __init__(
self,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: TokenToKVPool,
tree_cache: RadixCache,
):
self.reqs = reqs
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.tree_cache = tree_cache
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in reqs
)
def is_empty(self):
return len(self.reqs) == 0
def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
device = "cuda"
bs = len(self.reqs)
reqs = self.reqs
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
prefix_indices = [r.prefix_indices for r in reqs]
# Handle prefix
flatten_input_ids = []
extend_lens = []
prefix_lens = []
seq_lens = []
req_pool_indices = self.req_to_token_pool.alloc(bs)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
for i in range(bs):
flatten_input_ids.extend(input_ids[i])
extend_lens.append(len(input_ids[i]))
if len(prefix_indices[i]) == 0:
prefix_lens.append(0)
else:
prefix_lens.append(len(prefix_indices[i]))
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
: len(prefix_indices[i])
] = prefix_indices[i]
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
# Alloc mem
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
print("Prefill out of memory.")
self.tree_cache.pretty_print()
exit()
pt = 0
for i in range(bs):
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
] = out_cache_loc[pt : pt + extend_lens[i]]
pt += extend_lens[i]
# Handle logit bias
logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
logit_bias[i] = int_token_logit_bias
# Set fields
self.input_ids = torch.tensor(
flatten_input_ids, dtype=torch.int32, device=device
)
self.pixel_values = [r.pixel_values for r in reqs]
self.image_offsets = [
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
]
self.req_pool_indices = req_pool_indices
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
self.position_ids_offsets = position_ids_offsets
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
self.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
).view(-1, 1)
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
).view(-1, 1)
self.frequency_penalties = torch.tensor(
[r.sampling_params.frequency_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.presence_penalties = torch.tensor(
[r.sampling_params.presence_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.logit_bias = logit_bias
def update_for_decode(self, input_ids=None):
if input_ids is None:
input_ids = [
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
]
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
self.seq_lens.add_(1)
self.prefix_lens = None
# Alloc mem
bs = len(self.reqs)
alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
if alloc_res is None:
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
print("Decode out of memory.")
self.tree_cache.pretty_print()
exit()
self.out_cache_cont_start = None
self.out_cache_cont_end = None
else:
self.out_cache_loc = alloc_res[0]
self.out_cache_cont_start = alloc_res[1]
self.out_cache_cont_end = alloc_res[2]
self.req_to_token_pool.req_to_token[
self.req_pool_indices, self.seq_lens - 1
] = self.out_cache_loc
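# Note on the method above: each decode step allocates exactly one new KV slot per
# running request, preferring a contiguous range (so the KV cache write can be a
# single slice assignment via out_cache_cont_start/out_cache_cont_end) and falling
# back to scattered slots, evicting entries from the radix cache when the token
# pool is exhausted.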
def filter_batch(self, unfinished_indices: List[int]):
self.reqs = [self.reqs[i] for i in unfinished_indices]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
self.seq_lens = self.seq_lens[new_indices]
self.input_ids = None
self.req_pool_indices = self.req_pool_indices[new_indices]
self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
for item in [
"temperatures",
"top_ps",
"top_ks",
"frequency_penalties",
"presence_penalties",
"logit_bias",
]:
setattr(self, item, getattr(self, item)[new_indices])
def merge(self, other):
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.prefix_lens = None
self.position_ids_offsets = torch.concat(
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
for item in [
"temperatures",
"top_ps",
"top_ks",
"frequency_penalties",
"presence_penalties",
"logit_bias",
]:
setattr(
self, item, torch.concat([getattr(self, item), getattr(other, item)])
)
def sample(self, logits: torch.Tensor):
# Post process logits
logits = logits.contiguous()
logits.div_(self.temperatures)
logits.add_(self.logit_bias)
has_regex = any(req.regex_fsm is not None for req in self.reqs)
if has_regex:
allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
allowed_mask.zero_()
allowed_mask[
req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
] = 1
logits[i].masked_fill_(~allowed_mask, float("-inf"))
# TODO(lmzheng): apply penalty
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
sampled_index = torch.multinomial(probs_sort, num_samples=1)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-1
)
batch_next_token_probs = torch.gather(
probs_sort, dim=1, index=sampled_index
).view(-1)
if has_regex:
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.next_state(
req.regex_fsm_state, batch_next_token_ids_cpu[i]
)
return batch_next_token_ids, batch_next_token_probs
def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
return probs_sort, probs_idx
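# Hypothetical worked example for _top_p_top_k: with one row probs = [0.5, 0.3, 0.15, 0.05],
# top_p = 0.8 and top_k = 3, the cumulative-sum test zeroes 0.05 (the mass strictly
# before it is 0.95 > 0.8), the rank test zeroes everything past the third entry, and
# torch.multinomial then samples from the rescaled survivors via probs_idx.
def _top_p_top_k_example():
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    top_ps = torch.tensor([[0.8]])
    top_ks = torch.tensor([[3]])
    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled_index)  # sampled token id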
import asyncio
import logging
from typing import List, Tuple
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class RouterManager:
def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
# Init communication
context = zmq.asyncio.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
# Init status
self.model_client = model_client
self.recv_reqs = []
async def loop_for_forward(self):
while True:
next_step_input = list(self.recv_reqs)
self.recv_reqs = []
out_pyobjs = await self.model_client.step(next_step_input)
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# wait briefly to accept incoming requests
await asyncio.sleep(0.001)
async def loop_for_recv_requests(self):
while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj()
self.recv_reqs.append(recv_req)
def start_router_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
model_client = ModelRpcClient(server_args, port_args)
router = RouterManager(model_client, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_recv_requests())
loop.run_until_complete(router.loop_for_forward())
import asyncio
import logging
import multiprocessing
import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_exception_traceback,
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
)
logger = logging.getLogger("model_rpc")
class ModelRpcServer(rpyc.Service):
def exposed_init_model(
self,
tp_rank: int,
server_args: ServerArgs,
port_args: PortArgs,
):
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
# Copy arguments
self.model_mode = server_args.model_mode
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path, server_args.trust_remote_code
)
self.model_runner = ModelRunner(
self.model_config,
server_args.mem_fraction_static,
tp_rank,
server_args.tp_size,
port_args.nccl_port,
server_args.load_format,
server_args.trust_remote_code,
server_args.model_mode,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.eos_token_id = self.tokenizer.eos_token_id
self.max_total_num_token = self.model_runner.max_total_num_token
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len, self.max_total_num_token // 6
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
logger.info(
f"Rank {self.tp_rank}: "
f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, "
f"model_mode={self.model_mode}"
)
# Init cache
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
self.scheduler = Scheduler(
self.schedule_heuristic,
self.max_num_running_seq,
self.max_prefill_num_token,
self.max_total_num_token,
self.tree_cache,
)
self.req_to_token_pool = self.model_runner.req_to_token_pool
self.token_to_kv_pool = self.model_runner.token_to_kv_pool
# Init running status
self.forward_queue: List[Req] = []
self.running_batch: Batch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
self.stream_interval = 2
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(self.tokenizer)
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
recv_reqs = obtain(recv_reqs)
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
else:
raise ValueError(f"Invalid request: {recv_req}")
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
# Return results
ret = self.out_pyobjs
self.out_pyobjs = []
return ret
@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_fill_batch()
if new_batch is not None:
# Run new fill batch
self.forward_fill_batch(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
else:
# Run decode batch
if self.running_batch is not None:
# Run a few decode batches continuously to reduce overhead
for _ in range(10):
self.forward_decode_batch(self.running_batch)
if self.running_batch.is_empty():
self.running_batch = None
break
if self.running_batch is not None and self.tp_rank == 0:
if self.decode_forward_ct >= 20:
self.decode_forward_ct = 0
num_used = self.max_total_num_token - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
logger.info(
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_token:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid)
req.input_ids = recv_req.input_ids
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, pad_value
)
req.sampling_params = recv_req.sampling_params
req.return_normalized_logprob = recv_req.return_normalized_logprob
req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# init the regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
self.model_config.context_len - 1 - len(req.input_ids),
)
self.forward_queue.append(req)
def get_new_fill_batch(self):
if (
self.running_batch is not None
and len(self.running_batch.reqs) > self.max_num_running_seq
):
return None
for req in self.forward_queue:
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_normalized_logprob:
prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
req.prefix_indices = prefix_indices
req.last_node = last_node
# Get priority queue
self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
# Add requests if there is available space
can_run_list = []
new_batch_total_tokens = 0
new_batch_input_tokens = 0
new_batch_prefix_tokens = 0
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
new_ratio = self.scheduler.new_token_estimation_ratio()
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * new_ratio
for r in self.running_batch.reqs
]
)
for req in self.forward_queue:
if req.return_normalized_logprob:
# Need at least two tokens to compute normalized logprob
if req.adjust_input_len < 2:
delta = 2 - req.adjust_input_len
req.adjust_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
# Need at least one token to compute logits
req.adjust_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
if (
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and req.adjust_input_len + new_batch_input_tokens
< self.max_prefill_num_token
):
delta = self.tree_cache.inc_ref_counter(req.last_node)
available_size += delta
if not (
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
):
delta = self.tree_cache.dec_ref_counter(req.last_node)
available_size += delta
else:
self.token_to_kv_pool.add_refs(req.prefix_indices)
can_run_list.append(req)
new_batch_total_tokens += (
req.adjust_input_len + req.max_new_tokens()
)
new_batch_input_tokens += req.adjust_input_len
if len(can_run_list) == 0:
return None
if self.tp_rank == 0:
logger.info(
f"new fill batch. #seq: {len(can_run_list)}. "
f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
new_batch = Batch(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
)
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
return new_batch
def forward_fill_batch(self, batch: Batch):
# Build batch tensors
batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
if batch.extend_num_tokens != 0:
# Forward
logits, normalized_logprobs = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_normalized_logprob
)
# print("extend logits", logits)
if normalized_logprobs is not None:
normalized_logprobs = normalized_logprobs.cpu().tolist()
next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
normalized_logprobs = None
# Check finish condition
reqs = batch.reqs
for i in range(len(reqs)):
reqs[i].output_ids = [next_token_ids[i]]
reqs[i].check_finished()
if normalized_logprobs is not None:
reqs[i].normalized_logprob = normalized_logprobs[i]
self.handle_finished_requests(batch)
def forward_decode_batch(self, batch: Batch):
# Update batch tensors
self.decode_forward_ct += 1
batch.update_for_decode()
# Forward
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
# Check finish condition
reqs = batch.reqs
for i in range(len(reqs)):
reqs[i].output_ids.append(next_token_ids[i])
reqs[i].check_finished()
self.handle_finished_requests(batch)
def handle_finished_requests(self, batch: Batch):
output_rids = []
output_tokens = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_meta_info = []
output_finished = []
finished_indices = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if req.finished:
finished_indices.append(i)
else:
unfinished_indices.append(i)
if req.finished or (
req.stream and self.decode_forward_ct % self.stream_interval == 0
):
output_rids.append(req.rid)
output_tokens.append(req.output_ids)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
meta_info = {
"prompt_tokens": len(req.input_ids),
"completion_tokens": len(req.output_ids),
}
if req.return_normalized_logprob:
meta_info["normalized_logprob"] = req.normalized_logprob
output_meta_info.append(meta_info)
output_finished.append(req.finished)
# Send to detokenizer
if output_rids:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_tokens,
output_hit_stop_str,
output_skip_special_tokens,
output_meta_info,
output_finished,
)
)
# Remove finished reqs
if finished_indices:
# Update radix cache
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
for i in finished_indices:
req = batch.reqs[i]
req_pool_idx = req_pool_indices_cpu[i]
token_ids = tuple(req.input_ids + req.output_ids)
seq_len = len(token_ids) - 1
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
prefix_len = self.tree_cache.insert(token_ids, indices.clone())
self.token_to_kv_pool.free(indices[:prefix_len])
self.req_to_token_pool.free(req_pool_idx)
self.tree_cache.dec_ref_counter(req.last_node)
# Update batch tensors
if unfinished_indices:
batch.filter_batch(unfinished_indices)
else:
batch.reqs = []
class ModelRpcClient:
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
tp_size = server_args.tp_size
if tp_size == 1:
# Init model
self.model_server = ModelRpcServer()
self.model_server.exposed_init_model(0, server_args, port_args)
# Wrap functions
def async_wrap(f):
async def _func(*args, **kwargs):
return f(*args, **kwargs)
return _func
self.step = async_wrap(self.model_server.exposed_step)
else:
with ThreadPoolExecutor(tp_size) as executor:
# Launch model processes
rets = executor.map(start_model_process, port_args.model_rpc_ports)
self.model_servers = [x[0] for x in rets]
self.procs = [x[1] for x in rets]
# Init model
def init_model(i):
return self.model_servers[i].init_model(i, server_args, port_args)
rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
# Wrap functions
def async_wrap(func_name):
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
async def _func(*args, **kwargs):
tasks = [f(*args, **kwargs) for f in fs]
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
return obtain(tasks[0].value)
return _func
self.step = async_wrap("step")
def start_model_process(port):
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
)
t.start()
proc = multiprocessing.Process(target=_init_service, args=(port,))
proc.start()
time.sleep(1)
repeat_count = 0
while repeat_count < 20:
try:
con = rpyc.connect(
"localhost",
port,
config={"allow_pickle": True, "sync_request_timeout": 600},
)
break
except ConnectionRefusedError:
time.sleep(1)
repeat_count += 1
if repeat_count == 20:
raise RuntimeError("init rpc env error!")
assert proc.is_alive()
return con.root, proc
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
# for model_mode
global_model_mode: List[str] = []
@dataclass
class InputMetadata:
model_runner: "ModelRunner"
forward_mode: ForwardMode
batch_size: int
total_num_tokens: int
max_seq_len: int
req_pool_indices: torch.Tensor
start_loc: torch.Tensor
seq_lens: torch.Tensor
prefix_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
# for extend
extend_seq_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None
max_extend_len: int = 0
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None
return_normalized_logprob: bool = False
# for flashinfer
use_flashinfer: bool = False
qo_indptr: torch.Tensor = None
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_last_page_len: torch.Tensor = None
prefill_wrapper = None
decode_wrapper = None
def init_flashinfer_args(self, tp_size):
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
self.kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
self.req_pool_indices[i].item(), : self.seq_lens[i].item()
]
for i in range(self.batch_size)
],
dim=0,
).contiguous()
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
from flashinfer.ops import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
):
self.qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.batch_size,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_last_page_len,
self.batch_size,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim,
1,
"NONE",
"float16",
)
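# Shape sketch (assumed values, not from the original code): with seq_lens = [4, 2]
# and extend_seq_lens = [3, 2], init_flashinfer_args builds
#   kv_indptr = [0, 4, 6]        (exclusive prefix sum of seq_lens)
#   qo_indptr = [0, 3, 5]        (exclusive prefix sum of extend_seq_lens, prefill/extend only)
#   kv_last_page_len = [1, 1]    (the page size is fixed to 1 token)
# and kv_indices concatenates the first 4 and first 2 KV-cache slots recorded in
# req_to_token for the two requests.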
def init_extend_args(self):
self.extend_seq_lens = self.seq_lens - self.prefix_lens
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
self.max_extend_len = int(torch.max(self.extend_seq_lens))
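# Numeric sketch (assumed values): with seq_lens = [5, 3, 4] and prefix_lens = [2, 0, 1],
# init_extend_args gives extend_seq_lens = [3, 3, 3], extend_start_loc = [0, 3, 6]
# (exclusive prefix sum) and max_extend_len = 3.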
@classmethod
def create(
cls,
model_runner,
tp_size,
forward_mode,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
out_cache_cont_start=None,
out_cache_cont_end=None,
return_normalized_logprob=False,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
total_num_tokens = int(torch.sum(seq_lens))
max_seq_len = int(torch.max(seq_lens))
if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
other_kv_index = model_runner.req_to_token_pool.req_to_token[
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_np = seq_lens.cpu().numpy()
prefix_lens_np = prefix_lens.cpu().numpy()
position_ids_offsets_np = position_ids_offsets.cpu().numpy()
positions = torch.tensor(
np.concatenate(
[
np.arange(
prefix_lens_np[i] + position_ids_offsets_np[i],
seq_lens_np[i] + position_ids_offsets_np[i],
)
for i in range(batch_size)
],
axis=0,
),
device="cuda",
)
other_kv_index = None
ret = cls(
model_runner=model_runner,
forward_mode=forward_mode,
batch_size=batch_size,
total_num_tokens=total_num_tokens,
max_seq_len=max_seq_len,
req_pool_indices=req_pool_indices,
start_loc=start_loc,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
positions=positions,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
return_normalized_logprob=return_normalized_logprob,
other_kv_index=other_kv_index,
)
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()
ret.use_flashinfer = "flashinfer" in model_runner.model_mode
if ret.use_flashinfer:
ret.init_flashinfer_args(tp_size)
return ret
class ModelRunner:
def __init__(
self,
model_config,
mem_fraction_static,
tp_rank,
tp_size,
nccl_port,
load_format="auto",
trust_remote_code=True,
model_mode: List[str] = (),
):
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.tp_rank = tp_rank
self.tp_size = tp_size
self.nccl_port = nccl_port
self.load_format = load_format
self.trust_remote_code = trust_remote_code
self.model_mode = model_mode
global global_model_mode
global_model_mode = model_mode
# Init torch distributed
torch.cuda.set_device(self.tp_rank)
torch.distributed.init_process_group(
backend="nccl",
world_size=self.tp_size,
rank=self.tp_rank,
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
# A small all_reduce for warmup.
if self.tp_size > 1:
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
self.load_model()
self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config)
def load_model(self):
"""See also vllm/model_executor/model_loader.py::get_model"""
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llava import LlavaLlamaForCausalLM
from sglang.srt.models.mixtral import MixtralForCausalLM
# Select model class
architectures = getattr(self.model_config.hf_config, "architectures", [])
model_class = None
for arch in architectures:
if arch == "LlamaForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "MistralForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "LlavaLlamaForCausalLM":
model_class = LlavaLlamaForCausalLM
break
if arch == "MixtralForCausalLM":
model_class = MixtralForCausalLM
break
if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}")
# Load weights
linear_method = None
with _set_default_torch_dtype(torch.float16):
with torch.device("cuda"):
hf_quant_config = getattr(
self.model_config.hf_config, "quantization_config", None
)
if hf_quant_config is not None:
# TODO: support quantization configs other than AWQ
quant_config = AWQConfig.from_config(hf_quant_config)
print(f"quant_config: {quant_config}")
linear_method = quant_config.get_linear_method()
model = model_class(
config=self.model_config.hf_config, linear_method=linear_method
)
model.load_weights(
self.model_config.path,
cache_dir=None,
load_format=self.load_format,
revision=None,
)
self.model = model
def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
head_dim = (
self.model_config.hidden_size // self.model_config.num_attention_heads
)
head_num = self.model_config.num_key_value_heads // self.tp_size
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
max_num_token = int(rest_memory // cell_size)
return max_num_token
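# Rough arithmetic sketch (hypothetical model sizes): for a 7B Llama-like config with
# num_key_value_heads=32, head_dim=128, num_hidden_layers=32 and tp_size=1,
# cell_size = 32 * 128 * 32 * 2 (K and V) * 2 (fp16 bytes) = 524288 bytes per token,
# so about 80 GB of leftover memory holds roughly 160k cached tokens.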
def init_memory_pool(self, total_gpu_memory):
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
self.req_to_token_pool = ReqToTokenPool(
int(self.max_total_num_token / self.model_config.context_len * 256),
self.model_config.context_len + 8,
)
self.token_to_kv_pool = TokenToKVPool(
self.max_total_num_token,
dtype=torch.float16,
head_num=self.model_config.num_key_value_heads // self.tp_size,
head_dim=self.model_config.hidden_size
// self.model_config.num_attention_heads,
layer_num=self.model_config.num_hidden_layers,
)
@torch.inference_mode()
def forward_prefill(
self,
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
return_normalized_logprob,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.PREFILL,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@torch.inference_mode()
def forward_extend(
self,
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
return_normalized_logprob,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@torch.inference_mode()
def forward_decode(
self,
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
out_cache_cont_start,
out_cache_cont_end,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.DECODE,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
0
]
@torch.inference_mode()
def forward_extend_multi_modal(
self,
input_ids,
pixel_values,
image_offsets,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
return_normalized_logprob,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob,
)
return self.model.forward(
input_ids,
input_metadata.positions,
input_metadata,
pixel_values,
image_offsets,
)
def forward(
self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
):
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
kwargs = {
"input_ids": batch.input_ids,
"pixel_values": batch.pixel_values,
"image_offsets": batch.image_offsets,
"req_pool_indices": batch.req_pool_indices,
"seq_lens": batch.seq_lens,
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
}
kwargs["return_normalized_logprob"] = return_normalized_logprob
return self.forward_extend_multi_modal(**kwargs)
else:
kwargs = {
"input_ids": batch.input_ids,
"req_pool_indices": batch.req_pool_indices,
"seq_lens": batch.seq_lens,
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
}
if forward_mode == ForwardMode.DECODE:
kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
return self.forward_decode(**kwargs)
elif forward_mode == ForwardMode.EXTEND:
kwargs["return_normalized_logprob"] = return_normalized_logprob
return self.forward_extend(**kwargs)
elif forward_mode == ForwardMode.PREFILL:
kwargs["return_normalized_logprob"] = return_normalized_logprob
return self.forward_prefill(**kwargs)
else:
raise ValueError(f"Invaid forward mode: {forward_mode}")
import heapq
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Tuple
import torch
class TreeNode:
def __init__(self):
self.children = defaultdict(TreeNode)
self.parent = None
self.value = None
self.ref_counter = 0
self.last_access_time = time.time()
def __lt__(self, other):
return self.last_access_time < other.last_access_time
def match(key, seq):
i = 0
for k, w in zip(key, seq):
if k != w:
break
i += 1
return i
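# Prefix-match sketch: `match` returns the length of the common prefix of two
# sequences, e.g. match((1, 2, 3, 9), (1, 2, 7)) == 2, which is how RadixCache
# measures how far a request's token ids overlap an already cached path.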
class RadixCache:
def __init__(self, disable=False):
self.root_node = TreeNode()
self.root_node.value = []
self.root_node.ref_counter = 1
self.evictable_size_ = 0
self.disable = disable
##### Public API #####
def match_prefix(self, key):
if self.disable:
return [], self.root_node
value = []
last_node = [self.root_node]
self._match_prefix_helper(self.root_node, key, value, last_node)
if value:
value = torch.concat(value)
return value, last_node[0]
def insert(self, key, value=None):
if self.disable:
return len(key)
if value is None:
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def pretty_print(self):
self._print_helper(self.root_node, 0)
print(f"#tokens: {self.total_size()}")
def total_size(self):
return self._total_size_helper(self.root_node)
def evict(self, num_tokens, evict_callback):
if self.disable:
raise RuntimeError("Cannot evict from a disabled RadixCache.")
leaves = self._collect_leaves()
heapq.heapify(leaves)
num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
if x == self.root_node:
break
if x.ref_counter > 0:
continue
num_evicted += evict_callback(x.value)
self._delete_leaf(x)
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
def inc_ref_counter(self, node):
delta = 0
while node != self.root_node:
if node.ref_counter == 0:
self.evictable_size_ -= len(node.value)
delta -= len(node.value)
node.ref_counter += 1
node = node.parent
return delta
def dec_ref_counter(self, node):
delta = 0
while node != self.root_node:
if node.ref_counter == 1:
self.evictable_size_ += len(node.value)
delta += len(node.value)
node.ref_counter -= 1
node = node.parent
return delta
def evictable_size(self):
return self.evictable_size_
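# Reference-counting sketch (tuple keys and tensor values assumed, mirroring how the
# router stores KV-cache indices):
#
#   cache = RadixCache()
#   cache.insert((1, 2, 3), torch.arange(3, dtype=torch.int32))
#   prefix, last_node = cache.match_prefix((1, 2, 3, 4))  # prefix -> tensor([0, 1, 2])
#   cache.inc_ref_counter(last_node)  # pins the path; evictable_size() drops by 3
#   cache.dec_ref_counter(last_node)  # unpins it; the tokens become evictable again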
##### Internal Helper Functions #####
def _match_prefix_helper(self, node, key, value, last_node):
node.last_access_time = time.time()
for c_key, child in node.children.items():
prefix_len = match(c_key, key)
if prefix_len != 0:
if prefix_len == len(key) and prefix_len != len(c_key):
new_node = self._split_node(c_key, child, prefix_len)
value.append(new_node.value)
last_node[0] = new_node
else:
value.append(child.value[:prefix_len])
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
break
def _split_node(self, key, child, split_len):
# new_node -> child
new_node = TreeNode()
new_node.children = {key[split_len:]: child}
new_node.parent = child.parent
new_node.ref_counter = child.ref_counter
new_node.value = child.value[:split_len]
child.parent = new_node
child.value = child.value[split_len:]
new_node.parent.children[key[:split_len]] = new_node
del new_node.parent.children[key]
return new_node
def _insert_helper(self, node, key, value):
node.last_access_time = time.time()
for c_key, child in node.children.items():
prefix_len = match(c_key, key)
if prefix_len == len(c_key):
if prefix_len == len(key):
return prefix_len
else:
key = key[prefix_len:]
value = value[prefix_len:]
return prefix_len + self._insert_helper(child, key, value)
if prefix_len:
new_node = self._split_node(c_key, child, prefix_len)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.value = value
node.children[key] = new_node
self.evictable_size_ += len(value)
return 0
def _print_helper(self, node, indent):
for key, child in node.children.items():
print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
self._print_helper(child, indent=indent + 2)
def _delete_leaf(self, node):
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.evictable_size_ -= len(k)
def _total_size_helper(self, node):
x = len(node.value)
for child in node.children.values():
x += self._total_size_helper(child)
return x
def _collect_leaves(self):
ret_list = []
def dfs_(cur_node):
if len(cur_node.children) == 0:
ret_list.append(cur_node)
for x in cur_node.children.values():
dfs_(x)
dfs_(self.root_node)
return ret_list
if __name__ == "__main__":
tree = RadixCache(disable=False)
tree.insert("Hello")
tree.insert("Hello")
tree.insert("Hello_L.A.!")
# tree.insert("Hello_world! Happy")
# tree.insert("I love you!")
tree.pretty_print()
# print(tree.match_prefix("I love you! aha"))
# def evict_callback(x):
# print("evict", x)
# return len(x)
# tree.evict(5, evict_callback)
# tree.evict(10, evict_callback)
# tree.pretty_print()
import random
from collections import defaultdict
class Scheduler:
def __init__(
self,
schedule_heuristic,
max_running_seq,
max_prefill_num_token,
max_total_num_token,
tree_cache,
):
self.schedule_heuristic = schedule_heuristic
self.max_running_seq = max_running_seq
self.max_prefill_num_token = max_prefill_num_token
self.max_total_num_token = max_total_num_token
self.tree_cache = tree_cache
def new_token_estimation_ratio(self):
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm":
# longest prefix match
forward_queue.sort(key=lambda x: -len(x.prefix_indices))
return forward_queue
elif self.schedule_heuristic == "random":
random.shuffle(forward_queue)
return forward_queue
elif self.schedule_heuristic == "fcfs":
return forward_queue
elif self.schedule_heuristic == "weight":
last_node_to_reqs = defaultdict(list)
for req in forward_queue:
last_node_to_reqs[req.last_node].append(req)
for node in last_node_to_reqs:
last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
node_to_weight = defaultdict(int)
self._calc_weight_recursive(
self.tree_cache.root_node, last_node_to_reqs, node_to_weight
)
tmp_queue = []
self._get_weight_priority_recursive(
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
)
assert len(tmp_queue) == len(forward_queue)
return tmp_queue
else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
node_to_weight[cur_node] = 1
if cur_node in last_node_to_reqs:
node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
for child in cur_node.children.values():
self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
node_to_weight[cur_node] += node_to_weight[child]
def _get_weight_priority_recursive(
self, cur_node, node_to_weight, last_node_to_reqs, tmp_queue
):
visit_list = [child for child in cur_node.children.values()]
visit_list.sort(key=lambda x: -node_to_weight[x])
# for node in visit_list:
#     print(f"{node_to_weight[node]} {len(node.value) if node.value is not None else 0}")
for child in visit_list:
self._get_weight_priority_recursive(
child, node_to_weight, last_node_to_reqs, tmp_queue
)
tmp_queue.extend(last_node_to_reqs[cur_node])
import asyncio
import concurrent.futures
import dataclasses
import os
from typing import List
import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class ReqState:
out_list: List
finished: bool
event: asyncio.Event
lock: asyncio.Lock
global_processor = None
def init_global_processor(server_args: ServerArgs):
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def get_pixel_values(image_data, processor=None):
try:
processor = processor or global_processor
image = load_image(image_data)
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
class TokenizerManager:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
self.model_path = server_args.model_path
self.hf_config = get_config(
self.model_path, trust_remote_code=server_args.trust_remote_code
)
self.context_len = get_context_length(self.hf_config)
if is_multimodal_model(self.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor, initargs=(server_args,)
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.to_create_loop = True
self.rid_to_state = {} # Dict[str -> ReqState]
async def get_pixel_values(self, image_data):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, get_pixel_values, image_data
)
else:
return get_pixel_values(image_data, self.processor)
async def generate_request(self, obj: GenerateReqInput):
if self.to_create_loop:
await self.create_handle_loop()
is_single = isinstance(obj.text, str)
if is_single:
rid = obj.rid
input_ids = self.tokenizer.encode(obj.text)
sampling_params = SamplingParams(**obj.sampling_params)
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash = None, None
else:
pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob,
normalized_logprob_start_len=obj.normalized_logprob_start_len,
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
lock = asyncio.Lock()
event = asyncio.Event()
state = ReqState([], False, event, lock)
self.rid_to_state[rid] = state
while True:
await event.wait()
yield state.out_list[-1]
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
break
event.clear()
else:
assert obj.stream is False
bs = len(obj.text)
for i in range(bs):
rid = obj.rid[i]
input_ids = self.tokenizer.encode(obj.text[i])
sampling_params = SamplingParams(**obj.sampling_params[i])
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data[i] is None:
pixel_values, image_hash = None, None
else:
pixel_values, image_hash = await self.get_pixel_values(
obj.image_data[i]
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob[i],
normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
lock = asyncio.Lock()
event = asyncio.Event()
state = ReqState([], False, event, lock)
self.rid_to_state[rid] = state
output_list = []
for i in range(bs):
rid = obj.rid[i]
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(state.out_list[-1])
assert state.finished
del self.rid_to_state[rid]
yield output_list
async def create_handle_loop(self):
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
async def handle_loop(self):
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, BatchStrOut):
for i, rid in enumerate(recv_obj.rids):
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state = self.rid_to_state[rid]
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}")
"""Memory pool."""
import logging
import torch
logger = logging.getLogger(__name__)
class ReqToTokenPool:
def __init__(self, size, max_context_len):
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
self.can_use_mem_size = size
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device="cuda"
)
def alloc(self, need_size):
if need_size > self.can_use_mem_size:
return None
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
self.mem_state[select_index] = 0
self.can_use_mem_size -= need_size
return select_index.to(torch.int32)
def free(self, free_index):
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
else:
self.can_use_mem_size += free_index.shape[0]
self.mem_state[free_index] = 1
# if self.can_use_mem_size == len(self.mem_state):
# print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
def clear(self):
self.mem_state.fill_(1)
self.can_use_mem_size = len(self.mem_state)
class TokenToKVPool:
def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
self.alloc_ct = 0
# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id):
return self.kv_data[layer_id][:, 0]
def get_value_buffer(self, layer_id):
return self.kv_data[layer_id][:, 1]
def alloc(self, need_size):
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
if select_index.shape[0] < need_size:
return None
self.add_refs(select_index)
return select_index.to(torch.int32)
def alloc_contiguous(self, need_size):
empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
if empty_index.shape[0] < need_size:
return None
empty_size = len(empty_index)
loc_sum = (
empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
)
can_used_loc = empty_index[: empty_size - (need_size - 1)][
loc_sum == need_size - 1
]
if can_used_loc.shape[0] == 0:
return None
start_loc = can_used_loc[0].item()
select_index = torch.arange(start_loc, start_loc + need_size, device="cuda")
self.add_refs(select_index)
return select_index.to(torch.int32), start_loc, start_loc + need_size
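# Contiguity-check sketch (assumed values): with free slots empty_index = [2, 3, 4, 7, 8]
# and need_size = 3, loc_sum = empty_index[2:] - empty_index[:3] = [2, 4, 4]; entries
# equal to need_size - 1 mark the starts of runs of 3 consecutive free slots, so
# start_loc = 2 and the method returns indices [2, 3, 4] plus (start, end) = (2, 5).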
def free(self, free_index):
return self.decrease_refs(free_index)
def used_size(self):
return len(torch.nonzero(self.mem_state).squeeze(1))
def available_size(self):
return torch.sum(self.mem_state == 0).item()
def add_refs(self, token_index: torch.Tensor):
self.alloc_ct += len(token_index)
self.mem_state[token_index] += 1
def decrease_refs(self, token_index: torch.Tensor):
self.alloc_ct -= len(token_index)
self.mem_state[token_index] -= 1
num_freed = torch.sum(self.mem_state[token_index] == 0)
# if self.alloc_ct == 0:
# print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
return num_freed
def clear(self):
self.mem_state.fill_(0)
self.alloc_ct = 0
import os
from typing import Optional, Union
import torch
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class ModelConfig:
def __init__(
self,
path: str,
trust_remote_code: bool = True,
revision: Optional[str] = None,
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.hf_config = get_config(self.path, trust_remote_code, revision)
# Unify the config keys for hf_config
self.context_len = get_context_length(self.hf_config)
self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
self.num_key_value_heads = self.hf_config.num_key_value_heads
self.num_attention_heads = self.hf_config.num_attention_heads
self.hidden_size = self.hf_config.hidden_size
self.num_hidden_layers = self.hf_config.num_hidden_layers
self.vocab_size = self.hf_config.vocab_size
# Adapted from
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
linear_method=linear_method,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, linear_method=linear_method
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config, i, linear_method)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
if not skip_embed:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_ids
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
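# Weight-remapping sketch (hypothetical checkpoint name): a HuggingFace tensor named
# "model.layers.0.self_attn.q_proj.weight" is renamed to
# "model.layers.0.self_attn.qkv_proj.weight" and loaded into the fused QKV parameter
# with shard_id "q"; gate_proj/up_proj are merged into gate_up_proj the same way.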
"""Inference-only LLaVa model compatible with HuggingFace weights."""
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from torch import nn
from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
class LlavaLlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlavaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = LlamaForCausalLM(config, linear_method)
def pad_input_ids(self, input_ids, pad_value):
pad_ids = pad_value * (
(self.image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# The new length is old_len + pad_len - 1 because the single image_token_id is removed.
new_input_ids = (
input_ids[:offset]
+ pad_ids[: self.image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
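# Padding sketch (assumed values): with input_ids = [1, 9, <image_token_index>, 5],
# image_feature_len = 576 and a hash-derived pad_value list, the image token at
# offset 2 is replaced by 576 pad ids, so the returned ids are
# input_ids[:2] + pad_ids[:576] + input_ids[3:] and offset == 2.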
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
pixel_values: Optional[List[Optional[np.array]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = torch.tensor(
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
device=self.vision_tower.device,
)
image_outputs = self.vision_tower(
pixel_values, output_hidden_states=True
)
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[
self.vision_feature_layer
]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
image_features = self.multi_modal_projector(selected_image_feature)
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
return self.language_model(
input_embeds, positions, input_metadata, skip_embed=True
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(
input_ids, positions, input_metadata, skip_embed=False
)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# Load the CLIP vision model specified by config.mm_vision_tower:
# either a HuggingFace model name or a path relative to the LLaVA model directory.
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector
# TODO: support TP?
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
# FIXME: why are the projector weights read twice?
if "projector" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)
monkey_patch_clip_vision_embed_forward()
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_patch_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
# Adapted from
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
)
self.w2 = ReplicatedLinear(
self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
)
self.w3 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}."
)
# Split experts equally between ranks
self.expert_indicies = np.array_split(
range(self.num_total_experts), self.tp_size
)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList(
[
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
)
if idx in self.expert_indicies
else None
for idx in range(self.num_total_experts)
]
)
self.gate = ReplicatedLinear(
config.hidden_size, self.num_total_experts, bias=False, linear_method=None
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = selected_experts == expert_idx
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states)
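# Routing sketch (assumed config values): with num_local_experts = 8 and
# num_experts_per_tok = 2, softmax(router_logits) gives per-token weights over the 8
# experts, topk keeps the best 2 and renormalizes them, each rank evaluates only the
# experts in its expert_indicies slice, and the all_reduce sums the partial outputs so
# every rank ends up with the full MoE result.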
class MixtralAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than or equal to TP size, so we
# partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method,
)
self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
MixtralDecoderLayer(config, i, linear_method=linear_method)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
if not skip_embed:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_ids
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False
):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if "block_sparse_moe.experts." in name and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)