Unverified commit e1792cca, authored by Lianmin Zheng, committed by GitHub

Remove cached triton launcher (#656)

parent 1b7adbb5
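
This commit deletes the `wrap_kernel_launcher` helper from `sglang.srt.utils` along with the per-module `cached_kernel` globals in the Triton attention kernels, leaving only the standard `kernel[grid](...)` launch at every call site. It also moves the file-descriptor limit helper out of the server module into `sglang.srt.utils` as `set_ulimit`. The standard launch is sufficient because Triton compiles a kernel on its first launch and caches the compiled binary, keyed on the specialization and `constexpr` arguments, so later launches with the same key skip compilation. A minimal, self-contained sketch of that pattern with a toy kernel (illustrative only, not code from this repository):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the vectors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)

# The first launch compiles the kernel; Triton caches the binary, so
# subsequent launches with the same constexpr values reuse it.
_add_kernel[grid](x, y, out, x.numel(), BLOCK=256)
```
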
@@ -4,8 +4,6 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.utils import wrap_kernel_launcher
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -119,9 +117,6 @@ def _fwd_kernel(
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
 
-cached_kernel = None
-
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
     num_warps = 4 if Lk <= 64 else 8
 
-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q,
-            k,
-            v,
-            sm_scale,
-            b_start_loc,
-            b_seq_len,
-            o,
-            q.stride(0),
-            q.stride(1),
-            k.stride(0),
-            k.stride(1),
-            v.stride(0),
-            v.stride(1),
-            o.stride(0),
-            o.stride(1),
-        )
-        return
-
     _fwd_kernel[grid](
         q,
         k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
@@ -3,7 +3,6 @@ import triton
 import triton.language as tl
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
-from sglang.srt.utils import wrap_kernel_launcher
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -172,9 +171,6 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
-cached_kernel = None
-
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q_extend,
-            k_extend,
-            v_extend,
-            o_extend,
-            k_buffer,
-            v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_start_loc_extend,
-            b_seq_len_extend,
-            sm_scale,
-            kv_group_num,
-            q_extend.stride(0),
-            q_extend.stride(1),
-            k_extend.stride(0),
-            k_extend.stride(1),
-            v_extend.stride(0),
-            v_extend.stride(1),
-            o_extend.stride(0),
-            o_extend.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
         num_stages=num_stages,
        logit_cap=logit_cap,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
 def redundant_attention(
@@ -6,7 +6,6 @@ import triton
 import triton.language as tl
 from sglang.srt.server import global_server_args_dict
-from sglang.srt.utils import wrap_kernel_launcher
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)
 
-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2
 
-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
 
 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
     num_warps = 1
 
-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
 
 def token_attention_fwd(
@@ -51,6 +51,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
     }
 
-def _set_ulimit(target_soft_limit=65535):
-    import resource
-
-    resource_type = resource.RLIMIT_NOFILE
-    current_soft, current_hard = resource.getrlimit(resource_type)
-
-    if current_soft >= target_soft_limit:
-        logger.info(
-            f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
-        )
-    else:
-        try:
-            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
-            new_soft, new_hard = resource.getrlimit(resource_type)
-            logger.info(
-                f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
-            )
-        except ValueError as e:
-            logger.warn(f"Failed to set new limits: {e}")
-            logger.info(
-                f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
-            )
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -186,7 +163,7 @@ def launch_server(
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
-    _set_ulimit()
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
@@ -5,6 +5,7 @@ import fcntl
 import logging
 import os
 import random
+import resource
 import socket
 import struct
 import time
@@ -16,6 +17,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed as dist
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
     return logit_bias
-def wrap_kernel_launcher(kernel):
-    """A faster launcher for triton kernels."""
-    if int(triton.__version__.split(".")[0]) >= 3:
-        return None
-
-    gpu_id = torch.cuda.current_device()
-    kernels = kernel.cache[gpu_id].values()
-    kernel = next(iter(kernels))
-
-    # Different Triton versions use different low-level names
-    if hasattr(kernel, "cu_function"):
-        kfunction = kernel.cu_function
-    else:
-        kfunction = kernel.function
-
-    if hasattr(kernel, "c_wrapper"):
-        run = kernel.c_wrapper
-    else:
-        run = kernel.run
-
-    add_cluster_dim = True
-
-    def ret_func(grid, num_warps, *args):
-        nonlocal add_cluster_dim
-
-        try:
-            if add_cluster_dim:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    1,
-                    1,
-                    1,
-                    1,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-            else:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-        except TypeError:
-            add_cluster_dim = not add_cluster_dim
-            ret_func(grid, num_warps, *args)
-
-    return ret_func
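
A note on the helper removed above, since the diff shows its whole body: on Triton >= 3 it returned `None`, so every `if cached_kernel:` fast path deleted in this commit was already dead code there and only the plain `kernel[grid](...)` launch ran. On older Triton it pulled the compiled kernel out of `kernel.cache`, resolved the version-dependent attribute names (`cu_function` vs `function`, `c_wrapper` vs `run`), and probed at call time whether the low-level entry point expected the extra cluster-dimension arguments, flipping `add_cluster_dim` and retrying on `TypeError`. That probe-and-retry idiom, isolated as a sketch (the function and argument shapes here are illustrative, not Triton API; unlike this sketch, the original remembered which convention worked via a closure flag so later calls skip the failing attempt):

```python
def call_with_signature_fallback(fn, *args):
    """Try a newer calling convention first; fall back to the older one.

    Mirrors the removed launcher's add_cluster_dim toggle: rather than
    hard-coding one low-level signature per library version, it retries
    with a different argument list when a TypeError signals a mismatch.
    """
    try:
        # Hypothetical newer convention: extra trailing dimension arguments.
        return fn(*args, 1, 1, 1)
    except TypeError:
        # Hypothetical older convention: no trailing dimension arguments.
        return fn(*args)
```
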
 def is_multimodal_model(model):
     from sglang.srt.model_config import ModelConfig
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
 def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
     dist.barrier()
     dist.destroy_process_group()
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            logger.warn(f"Failed to set RLIMIT_NOFILE: {e}")