Unverified commit 2cea6146, authored by Lianmin Zheng, committed by GitHub

Improve logging & add logit cap (#471)

parent 44c998fc
@@ -30,7 +30,7 @@ if __name__ == "__main__":
    response = requests.post(
        url + "/generate",
        json={
-           "text": f"{a}, ",
+           "text": f"The capital of France is",
            # "input_ids": [[2] * 256] * 196,
            "sampling_params": {
                "temperature": 0,
...
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
        super().__init__(enable=enable)
+       if tokenizer_path.endswith(".json"):
+           return
+
        from importlib.metadata import version
        if version("outlines") >= "0.0.35":
...
@@ -84,6 +84,9 @@ def get_tokenizer(
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+   if tokenizer_name.endswith(".json"):
+       return TiktokenTokenizer(tokenizer_name)
+
    """Gets a tokenizer for the given model name via Huggingface."""
    if is_multimodal_model(tokenizer_name):
        processor = get_processor(
@@ -170,3 +173,24 @@ def get_processor(
        **kwargs,
    )
    return processor
+
+
+class TiktokenTokenizer:
+   def __init__(self, tokenizer_path):
+       import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
+
+       tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+       self.tokenizer = tokenizer
+       self.eos_token_id = tokenizer.eos_token
+       self.vocab_size = tokenizer.n_vocab
+
+   def encode(self, x):
+       return self.tokenizer.encode(x)
+
+   def decode(self, x):
+       return self.tokenizer.decode(x)
+
+   def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
+       return self.tokenizer.decode_batch(batch)
+
+   def convert_ids_to_tokens(self, index):
+       return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
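
Since `xlm.tokenizers.tiktoken_wrapper` is an internal module, here is a hedged sketch of the same wrapper interface built on the public `tiktoken` package instead; the encoding name and the use of `eot_token` for the EOS id are assumptions, not taken from this commit:

```python
# Hypothetical sketch: mimics the wrapper's interface on top of the open-source
# `tiktoken` package (pip install tiktoken). Names and defaults are assumptions.
import tiktoken


class SimpleTiktokenTokenizer:
    def __init__(self, encoding_name: str = "cl100k_base"):
        self.tokenizer = tiktoken.get_encoding(encoding_name)
        # tiktoken exposes the end-of-text id as `eot_token`; the wrapper in the
        # diff reads `tokenizer.eos_token` from its internal Encoding instead.
        self.eos_token_id = self.tokenizer.eot_token
        self.vocab_size = self.tokenizer.n_vocab

    def encode(self, text: str):
        return self.tokenizer.encode(text)

    def decode(self, ids):
        return self.tokenizer.decode(ids)

    def batch_decode(self, batch, skip_special_tokens=True,
                     spaces_between_special_tokens=False):
        # The extra flags mirror the HF signature but are ignored here.
        return self.tokenizer.decode_batch(batch)

    def convert_ids_to_tokens(self, index: int):
        return self.tokenizer.decode_single_token_bytes(index).decode(
            "utf-8", errors="ignore")


if __name__ == "__main__":
    tok = SimpleTiktokenTokenizer()
    ids = tok.encode("The capital of France is")
    print(ids, tok.decode(ids))
```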
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher

CUDA_CAPABILITY = torch.cuda.get_device_capability()


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
@triton.jit
def _fwd_kernel(
    Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
+   logit_cap: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
+
+       if logit_cap > 0:
+           qk = logit_cap * tanh(qk / logit_cap)
+
        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
+
+       if logit_cap > 0:
+           qk = logit_cap * tanh(qk / logit_cap)
+
        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
            start_n + offs_n[None, :]
        )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
    b_seq_len_extend,
    max_len_in_batch,
    max_len_extend,
+   logit_cap=-1,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
+       logit_cap=logit_cap,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
...
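
For reference, outside Triton the capping applied above amounts to a single tanh squashing of the scaled attention scores. A minimal PyTorch sketch of the same operation (toy shapes, not the kernel's memory layout):

```python
# Reference sketch (not the Triton kernel): logit capping squashes the scaled
# scores into (-logit_cap, logit_cap), matching qk = logit_cap * tanh(qk / logit_cap).
import torch

def capped_attention_scores(q, k, sm_scale, logit_cap=-1.0):
    qk = (q @ k.transpose(-2, -1)) * sm_scale   # [..., q_len, kv_len]
    if logit_cap > 0:
        qk = logit_cap * torch.tanh(qk / logit_cap)
    return qk

# Illustrative shapes: [batch, heads, seq_len, head_dim]
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 32, 64)
scores = capped_attention_scores(q, k, sm_scale=64 ** -0.5, logit_cap=30.0)
assert scores.abs().max() < 30.0  # tanh keeps every score strictly inside the cap
```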
import torch
+import numpy as np
from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata

class RadixAttention(nn.Module):
-   def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+   def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.layer_id = layer_id
+       self.logit_cap = logit_cap
+
+       assert np.allclose(scaling, 1.0 / (head_dim**0.5))

        from sglang.srt.managers.router.model_runner import global_server_args_dict
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
+           self.logit_cap,
        )
        self.store_kv_cache(k, v, input_metadata)
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
            input_metadata.extend_seq_lens,
            input_metadata.max_seq_len,
            input_metadata.max_extend_len,
+           self.logit_cap,
        )
        return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
            input_metadata.max_seq_len,
            input_metadata.other_kv_index,
            input_metadata.total_num_tokens,
+           self.logit_cap,
        )
        return o
...
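
The new `logit_cap` argument defaults to -1, so capping stays disabled unless a model passes a positive value. A hypothetical usage sketch (the cap value 30.0, the shapes, and the import path are placeholders, not taken from this commit):

```python
# Hypothetical usage sketch: how a model definition might opt into logit capping
# with the new constructor argument. Import path assumed from the repo layout.
from sglang.srt.layers.radix_attention import RadixAttention

num_heads, head_dim, num_kv_heads, layer_id = 32, 128, 8, 0
attn = RadixAttention(
    num_heads=num_heads,
    head_dim=head_dim,
    scaling=head_dim ** -0.5,   # must satisfy the new np.allclose assert
    num_kv_heads=num_kv_heads,
    layer_id=layer_id,
    logit_cap=30.0,             # placeholder value; logit_cap <= 0 means "off"
)
```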
@@ -16,6 +16,12 @@ else:
    REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
@triton.jit
def _fwd_kernel_stage1(
    Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
+   logit_cap: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
        ).to(REDUCE_TRITON_TYPE)
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
+
+       if logit_cap > 0:
+           att_value = logit_cap * tanh(att_value / logit_cap)
+
        off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@@ -165,6 +176,7 @@ def _token_att_m_fwd(
    B_Start_Loc,
    B_Seqlen,
    max_len_in_batch,
+   logit_cap,
):
    BLOCK = 32
    # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
+       logit_cap=logit_cap,
        num_warps=num_warps,
        num_stages=1,
    )
@@ -304,6 +317,7 @@ def token_attention_fwd(
    max_len_in_batch,
    other_kv_index,
    total_num_tokens,
+   logit_cap=-1,
    att_m=None,
):
    if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
+       logit_cap,
    )
    _token_softmax_reducev_fwd(
        att_m,
...
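
Both kernels define `tanh` as a scaled sigmoid, relying on the identity tanh(x) = 2 * sigmoid(2x) - 1. A quick numerical check of that identity in PyTorch:

```python
# Verify the identity used by the Triton tanh helper: tanh(x) = 2 * sigmoid(2x) - 1.
import torch

x = torch.linspace(-10, 10, steps=1001, dtype=torch.float32)
assert torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1, atol=1e-6)
```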
import asyncio
+import inspect
import uvloop
import zmq
@@ -7,7 +8,7 @@ import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback, graceful_registry

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -85,6 +86,8 @@ def start_detokenizer_process(
    port_args: PortArgs,
    pipe_writer,
):
+   graceful_registry(inspect.currentframe().f_code.co_name)
+
    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception as e:
...
@@ -106,8 +106,7 @@ class ModelRpcServer:
        set_random_seed(server_args.random_seed)

        # Print info
-       logger.info(
-           f"Rank {self.tp_rank}: "
+       logger.info(f"[rank={self.tp_rank}] "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
@@ -752,7 +751,7 @@ def _init_service(port):
        protocol_config={
            "allow_public_attrs": True,
            "allow_pickle": True,
-           "sync_request_timeout": 1800,
+           "sync_request_timeout": 3600,
        },
    )
    t.start()
@@ -772,7 +771,7 @@ def start_model_process(port):
            config={
                "allow_public_attrs": True,
                "allow_pickle": True,
-               "sync_request_timeout": 1800,
+               "sync_request_timeout": 3600,
            },
        )
        break
...
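
The `sync_request_timeout` bump to 3600 s applies to both ends of the rpyc channel. A hedged, self-contained sketch of the same pattern (the echo service, host, and port are illustrative assumptions, not part of this commit):

```python
# Sketch: both the rpyc server's protocol_config and the client's connect()
# config carry the longer sync_request_timeout, mirroring the diff above.
import rpyc
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    def exposed_echo(self, x):
        return x


config = {
    "allow_public_attrs": True,
    "allow_pickle": True,
    "sync_request_timeout": 3600,  # allow long-running calls such as weight loading
}

# Server side (blocking; run in a separate process or thread in practice)
server = ThreadedServer(EchoService, port=18861, protocol_config=config)
# server.start()

# Client side
# conn = rpyc.connect("127.0.0.1", 18861, config=config)
# print(conn.root.echo("ping"))
```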
@@ -235,8 +235,8 @@ class ModelRunner:
        }

        # Init torch distributed
-       logger.debug("Init torch begin.")
        torch.cuda.set_device(self.tp_rank)
+       logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
@@ -244,20 +244,22 @@ class ModelRunner:
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-       logger.debug("Init torch end.")
+       logger.info(f"[rank={self.tp_rank}] Init torch end.")

+       total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+
+       if self.tp_size > 1:
+           total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+           if total_local_gpu_memory < total_gpu_memory * 0.9:
+               raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
-       total_gpu_memory = get_available_gpu_memory(
-           self.tp_rank, distributed=self.tp_size > 1
-       ) * (1 << 30)
-
-       # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
        self.load_model()
-       # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
-       logger.info(f"Rank {self.tp_rank}: load weight begin.")
+       logger.info(f"[rank={self.tp_rank}] Load weight begin.")
        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -283,19 +285,19 @@ class ModelRunner:
            parallel_config=None,
            scheduler_config=None,
        )
-       logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
+       logger.info(f"[rank={self.tp_rank}] Load weight end. "
+                   f"Type={type(self.model).__name__}. "
+                   f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")

    def profile_max_num_token(self, total_gpu_memory):
-       available_gpu_memory = get_available_gpu_memory(
-           self.tp_rank, distributed=self.tp_size > 1
-       ) * (1 << 30)
+       available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
-       max_num_token = int(rest_memory // cell_size)
+       max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
...
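
The reworked memory accounting keeps `get_available_gpu_memory` in GB and converts to bytes only when dividing by the per-token KV-cache cell size. A rough standalone sketch of that arithmetic (the model dimensions and the 0.9 static fraction are illustrative assumptions):

```python
# Standalone sketch of the max_num_token estimate above. In the real code,
# total_gpu_memory is measured before loading weights and available_gpu_memory
# after; here both are sampled at once just to show the units and formula.
import torch

def available_gpu_memory_gb(gpu_id: int) -> float:
    # Free CUDA memory on cuda:gpu_id, in GB (same unit the helper returns).
    free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_id)
    return free_bytes / (1 << 30)

# Hypothetical model shape and settings (roughly 7B-sized, fp16 KV cache).
head_dim, num_kv_heads, num_layers = 128, 32, 32
mem_fraction_static = 0.9

total_gpu_memory = available_gpu_memory_gb(0)
available_gpu_memory = available_gpu_memory_gb(0)

# Bytes per cached token: K and V (2 tensors) in fp16 (2 bytes) across all layers and heads.
cell_size = num_kv_heads * head_dim * num_layers * 2 * 2
rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
print(f"max_num_token ~= {max_num_token}")
```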
@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
        time.sleep(0.5)
        try:
            requests.get(url + "/get_model_info", timeout=5, headers=headers)
-           success = True  # Set flag to True if request succeeds
            break
        except requests.exceptions.RequestException as e:
            pass
@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
    res = requests.post(
        url + "/generate",
        json={
-           "text": "Say this is a warmup request.",
+           "text": "The capital city of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 16,
...
@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
    return wrapper


-def get_available_gpu_memory(gpu_id, distributed=True):
+def get_available_gpu_memory(gpu_id, distributed=False):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
...
@@ -2,7 +2,8 @@
import base64
import json
-import os
+import logging
+import signal
import sys
import threading
import traceback
@@ -15,6 +16,9 @@ import numpy as np
import requests

+logger = logging.getLogger(__name__)
+

def get_exception_traceback():
    etype, value, tb = sys.exc_info()
    err_str = "".join(traceback.format_exception(etype, value, tb))
@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
        raise RuntimeError()

    return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+   def graceful_shutdown(signum, frame):
+       logger.info(f"{sub_module_name} received a signal to shut down. Performing graceful shutdown...")
+       if signum == signal.SIGTERM:
+           logger.info(f"{sub_module_name} received SIGTERM.")
+
+   signal.signal(signal.SIGTERM, graceful_shutdown)
\ No newline at end of file
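
A hedged usage sketch of the new `graceful_registry` helper, mirroring the call added to `start_detokenizer_process` (the worker loop is illustrative, not part of this commit):

```python
# Sketch: a subprocess registers a SIGTERM handler under its own name so that
# shutdown shows up in the logs, following the pattern used in the diff above.
import inspect
import logging
import time

from sglang.utils import graceful_registry

logging.basicConfig(level=logging.INFO)


def start_worker_process():
    # Same pattern as start_detokenizer_process: register under the current
    # function's name.
    graceful_registry(inspect.currentframe().f_code.co_name)
    while True:
        time.sleep(1)  # placeholder for real work


if __name__ == "__main__":
    start_worker_process()
```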