Unverified Commit e1792cca authored by Lianmin Zheng, committed by GitHub

Remove cached triton launcher (#656)
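
This drops the `wrap_kernel_launcher` fast path and the per-module `cached_kernel` globals from the context, extend, and token attention kernels. The helper already returned `None` on Triton >= 3, so those call sites fell back to the standard `_fwd_kernel[grid](...)` launch anyway; after this change that is the only launch path. The `_set_ulimit` helper also moves from `server.py` into `sglang.srt.utils` as `set_ulimit`.

For reference, the launch pattern that remains everywhere is the ordinary Triton one, where the `@triton.jit` wrapper caches compiled kernels per specialization. A minimal self-contained sketch (illustrative only, not code from this repo):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the tensors.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Standard Triton launch: the JIT cache makes repeated calls cheap,
    # which is the overhead the removed hand-rolled launcher tried to avoid.
    _add_kernel[grid](x, y, out, n_elements, BLOCK=1024, num_warps=4)
    return out
```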

parent 1b7adbb5
@@ -4,8 +4,6 @@ import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -119,9 +117,6 @@ def _fwd_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
if CUDA_CAPABILITY[0] >= 8:
BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
)
return
_fwd_kernel[grid](
q,
k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
num_warps=num_warps,
num_stages=1,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
@@ -3,7 +3,6 @@ import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -172,9 +171,6 @@ def _fwd_kernel(
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
cached_kernel = None
def extend_attention_fwd(
q_extend,
k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
num_warps = 4 if Lk <= 64 else 8
num_stages = 1
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
)
return
_fwd_kernel[grid](
q_extend,
k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
num_stages=num_stages,
logit_cap=logit_cap,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
def redundant_attention(
......
@@ -6,7 +6,6 @@ import triton
import triton.language as tl
from sglang.srt.server import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher
if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
tl.store(out_ptrs, acc)
cached_kernel_stage1 = None
cached_kernel_stage2 = None
def _token_att_m_fwd(
q,
k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
else:
num_warps = 2
global cached_kernel_stage1
if cached_kernel_stage1:
cached_kernel_stage1(
grid,
num_warps,
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
)
return
_fwd_kernel_stage1[grid](
q,
k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
num_warps=num_warps,
num_stages=1,
)
cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
num_warps = 1
global cached_kernel_stage2
if cached_kernel_stage2:
cached_kernel_stage2(
grid,
num_warps,
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
)
return
_fwd_kernel_stage2[grid](
logics,
v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
num_warps=num_warps,
num_stages=3,
)
cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
def token_attention_fwd(
......
@@ -51,6 +51,7 @@ from sglang.srt.utils import (
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
set_ulimit,
)
from sglang.utils import get_exception_traceback
@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
}
def _set_ulimit(target_soft_limit=65535):
import resource
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
if current_soft >= target_soft_limit:
logger.info(
f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
)
else:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
new_soft, new_hard = resource.getrlimit(resource_type)
logger.info(
f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
)
except ValueError as e:
logger.warn(f"Failed to set new limits: {e}")
logger.info(
f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
)
def launch_server(
server_args: ServerArgs,
model_overide_args: Optional[dict] = None,
@@ -186,7 +163,7 @@ def launch_server(
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
_set_ulimit()
set_ulimit()
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_disk_cache:
......
@@ -5,6 +5,7 @@ import fcntl
import logging
import os
import random
import resource
import socket
import struct
import time
@@ -16,6 +17,7 @@ import numpy as np
import psutil
import requests
import torch
import torch.distributed as dist
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
return logit_bias
def wrap_kernel_launcher(kernel):
"""A faster launcher for triton kernels."""
if int(triton.__version__.split(".")[0]) >= 3:
return None
gpu_id = torch.cuda.current_device()
kernels = kernel.cache[gpu_id].values()
kernel = next(iter(kernels))
# Different triton versions use different low-level names
if hasattr(kernel, "cu_function"):
kfunction = kernel.cu_function
else:
kfunction = kernel.function
if hasattr(kernel, "c_wrapper"):
run = kernel.c_wrapper
else:
run = kernel.run
add_cluster_dim = True
def ret_func(grid, num_warps, *args):
nonlocal add_cluster_dim
try:
if add_cluster_dim:
run(
grid[0],
grid[1],
grid[2],
num_warps,
1,
1,
1,
1,
kernel.shared,
0,
kfunction,
None,
None,
kernel,
*args,
)
else:
run(
grid[0],
grid[1],
grid[2],
num_warps,
kernel.shared,
0,
kfunction,
None,
None,
kernel,
*args,
)
except TypeError:
add_cluster_dim = not add_cluster_dim
ret_func(grid, num_warps, *args)
return ret_func
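Note on the deleted helper above: it grabbed the single compiled variant from `kernel.cache[gpu_id]` and called its low-level `c_wrapper`/`run` entry point directly, retrying with and without the extra cluster-dimension arguments when the signature did not match. On Triton >= 3 it already returned `None`; the same version check in isolation (illustrative only, not part of this diff):

```python
import triton

# wrap_kernel_launcher bailed out on Triton 3.x, so those call sites were
# already using the plain kernel[grid](...) launch on current Triton.
major = int(triton.__version__.split(".")[0])
print(f"Triton {triton.__version__}: cached launcher {'disabled' if major >= 3 else 'enabled'}")
```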
def is_multimodal_model(model):
from sglang.srt.model_config import ModelConfig
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
def send_addrs_to_rank_0(model_port_args, server_args):
assert server_args.node_rank != 0 and server_args.dp_size == 1
import torch.distributed as dist
ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
def receive_addrs(model_port_args, server_args):
assert server_args.node_rank == 0 and server_args.dp_size == 1
import torch.distributed as dist
ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
dist.barrier()
dist.destroy_process_group()
def set_ulimit(target_soft_limit=65535):
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
if current_soft < target_soft_limit:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
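As context, a minimal standalone check (not part of this commit) of the limit this helper raises; the import mirrors how `launch_server` now calls it, assuming `sglang` is installed:

```python
import resource

from sglang.srt.utils import set_ulimit

# RLIMIT_NOFILE caps how many file descriptors (sockets, files) the process
# may keep open; the server raises the soft limit once at startup.
print("before:", resource.getrlimit(resource.RLIMIT_NOFILE))
set_ulimit()  # no-op when the soft limit is already >= 65535
print("after: ", resource.getrlimit(resource.RLIMIT_NOFILE))
```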