"src/graph/sampling/randomwalks/randomwalks_cpu.h" did not exist on "a9dabcc769554bd3c8daff7d6b76d3104910b445"
Unverified Commit 2cea6146 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve logging & add logit cap (#471)

parent 44c998fc
......@@ -30,7 +30,7 @@ if __name__ == "__main__":
response = requests.post(
url + "/generate",
json={
"text": f"{a}, ",
"text": f"The capital of France is",
# "input_ids": [[2] * 256] * 196,
"sampling_params": {
"temperature": 0,
......
......@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
super().__init__(enable=enable)
if tokenizer_path.endswith(".json"):
return
from importlib.metadata import version
if version("outlines") >= "0.0.35":
......
......@@ -84,6 +84,9 @@ def get_tokenizer(
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if tokenizer_name.endswith(".json"):
return TiktokenTokenizer(tokenizer_name)
"""Gets a tokenizer for the given model name via Huggingface."""
if is_multimodal_model(tokenizer_name):
processor = get_processor(
......@@ -170,3 +173,24 @@ def get_processor(
**kwargs,
)
return processor
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
self.tokenizer = tokenizer
self.eos_token_id = tokenizer.eos_token
self.vocab_size = tokenizer.n_vocab
def encode(self, x):
return self.tokenizer.encode(x)
def decode(self, x):
return self.tokenizer.decode(x)
def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index):
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
......@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel(
Q_Extend,
......@@ -39,6 +45,7 @@ def _fwd_kernel(
BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -90,6 +97,10 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
......@@ -126,6 +137,10 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
......@@ -176,6 +191,7 @@ def extend_attention_fwd(
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
logit_cap=-1,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
......@@ -271,6 +287,7 @@ def extend_attention_fwd(
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
logit_cap=logit_cap,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
......
import torch
import numpy as np
from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
......@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
class RadixAttention(nn.Module):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.layer_id = layer_id
self.logit_cap = logit_cap
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
from sglang.srt.managers.router.model_runner import global_server_args_dict
......@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
self.logit_cap,
)
self.store_kv_cache(k, v, input_metadata)
......@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
input_metadata.extend_seq_lens,
input_metadata.max_seq_len,
input_metadata.max_extend_len,
self.logit_cap,
)
return o
......@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
input_metadata.max_seq_len,
input_metadata.other_kv_index,
input_metadata.total_num_tokens,
self.logit_cap,
)
return o
......
......@@ -16,6 +16,12 @@ else:
REDUCE_TORCH_TYPE = torch.float16
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_stage1(
Q,
......@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
if logit_cap > 0:
att_value = logit_cap * tanh(att_value / logit_cap)
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
......@@ -165,6 +176,7 @@ def _token_att_m_fwd(
B_Start_Loc,
B_Seqlen,
max_len_in_batch,
logit_cap,
):
BLOCK = 32
# shape constraints
......@@ -223,6 +235,7 @@ def _token_att_m_fwd(
kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=1,
)
......@@ -304,6 +317,7 @@ def token_attention_fwd(
max_len_in_batch,
other_kv_index,
total_num_tokens,
logit_cap=-1,
att_m=None,
):
if att_m is None:
......@@ -320,6 +334,7 @@ def token_attention_fwd(
b_start_loc,
b_seq_len,
max_len_in_batch,
logit_cap,
)
_token_softmax_reducev_fwd(
att_m,
......
import asyncio
import inspect
import uvloop
import zmq
......@@ -7,7 +8,7 @@ import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback
from sglang.utils import get_exception_traceback, graceful_registry
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -85,6 +86,8 @@ def start_detokenizer_process(
port_args: PortArgs,
pipe_writer,
):
graceful_registry(inspect.currentframe().f_code.co_name)
try:
manager = DetokenizerManager(server_args, port_args)
except Exception as e:
......
......@@ -106,8 +106,7 @@ class ModelRpcServer:
set_random_seed(server_args.random_seed)
# Print info
logger.info(
f"Rank {self.tp_rank}: "
logger.info(f"[rank={self.tp_rank}] "
f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, "
......@@ -752,7 +751,7 @@ def _init_service(port):
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 1800,
"sync_request_timeout": 3600,
},
)
t.start()
......@@ -772,7 +771,7 @@ def start_model_process(port):
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 1800,
"sync_request_timeout": 3600,
},
)
break
......
......@@ -235,8 +235,8 @@ class ModelRunner:
}
# Init torch distributed
logger.debug("Init torch begin.")
torch.cuda.set_device(self.tp_rank)
logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
torch.distributed.init_process_group(
backend="nccl",
world_size=self.tp_size,
......@@ -244,20 +244,22 @@ class ModelRunner:
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
logger.debug("Init torch end.")
logger.info(f"[rank={self.tp_rank}] Init torch end.")
total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
if total_local_gpu_memory < total_gpu_memory * 0.9:
raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
total_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
# logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.load_model()
# logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config)
def load_model(self):
logger.info(f"Rank {self.tp_rank}: load weight begin.")
logger.info(f"[rank={self.tp_rank}] Load weight begin.")
device_config = DeviceConfig()
load_config = LoadConfig(load_format=self.server_args.load_format)
......@@ -283,19 +285,19 @@ class ModelRunner:
parallel_config=None,
scheduler_config=None,
)
logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
logger.info(f"[rank={self.tp_rank}] Load weight end. "
f"Type={type(self.model).__name__}. "
f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
head_dim = self.model_config.head_dim
head_num = self.model_config.num_key_value_heads // self.tp_size
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
max_num_token = int(rest_memory // cell_size)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def init_memory_pool(self, total_gpu_memory):
......
......@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5, headers=headers)
success = True # Set flag to True if request succeeds
break
except requests.exceptions.RequestException as e:
pass
......@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
res = requests.post(
url + "/generate",
json={
"text": "Say this is a warmup request.",
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
......
......@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
return wrapper
def get_available_gpu_memory(gpu_id, distributed=True):
def get_available_gpu_memory(gpu_id, distributed=False):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
......
......@@ -2,7 +2,8 @@
import base64
import json
import os
import logging
import signal
import sys
import threading
import traceback
......@@ -15,6 +16,9 @@ import numpy as np
import requests
logger = logging.getLogger(__name__)
def get_exception_traceback():
etype, value, tb = sys.exc_info()
err_str = "".join(traceback.format_exception(etype, value, tb))
......@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
raise RuntimeError()
return ret_value[0]
def graceful_registry(sub_module_name):
def graceful_shutdown(signum, frame):
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
if signum == signal.SIGTERM:
logger.info(f"{sub_module_name} recive sigterm")
signal.signal(signal.SIGTERM, graceful_shutdown)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment