change / sglang · Commits · e1792cca

Commit e1792cca (Unverified)
Authored Jul 18, 2024 by Lianmin Zheng; committed by GitHub on Jul 18, 2024

    Remove cached triton launcher (#656)

Parent: 1b7adbb5

Changes: 5 changed files with 15 additions and 210 deletions (+15 −210)

  python/sglang/srt/layers/context_flashattention_nopad.py   +0  −29
  python/sglang/srt/layers/extend_attention.py                +0  −39
  python/sglang/srt/layers/token_attention.py                 +0  −50
  python/sglang/srt/server.py                                 +2  −25
  python/sglang/srt/utils.py                                  +13 −67
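The commit deletes the wrap_kernel_launcher helper and the module-level cached_kernel globals in the three attention layers, so every call site now goes through Triton's standard kernel[grid](...) launch, which handles compilation caching internally. The snippet below is a minimal, self-contained illustration of that standard launch path; it is not taken from the sglang sources, and the names _add_kernel and add are invented for the example.

import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # One program instance handles one BLOCK-sized slice of the vectors.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    BLOCK = 128
    grid = (triton.cdiv(n, BLOCK),)
    # Standard launch: Triton compiles on the first call and reuses the cached
    # binary afterwards, so no wrap_kernel_launcher-style caching is needed.
    _add_kernel[grid](x, y, out, n, BLOCK=BLOCK, num_warps=4, num_stages=1)
    return out


if __name__ == "__main__":
    a = torch.randn(8192, device="cuda")
    b = torch.randn(8192, device="cuda")
    torch.testing.assert_close(add(a, b), a + b)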
python/sglang/srt/layers/context_flashattention_nopad.py

@@ -4,8 +4,6 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.utils import wrap_kernel_launcher
-
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -119,9 +117,6 @@ def _fwd_kernel(
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


-cached_kernel = None
-
-
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
     num_warps = 4 if Lk <= 64 else 8

-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q,
-            k,
-            v,
-            sm_scale,
-            b_start_loc,
-            b_seq_len,
-            o,
-            q.stride(0),
-            q.stride(1),
-            k.stride(0),
-            k.stride(1),
-            v.stride(0),
-            v.stride(1),
-            o.stride(0),
-            o.stride(1),
-        )
-        return
-
     _fwd_kernel[grid](
         q,
         k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
python/sglang/srt/layers/extend_attention.py

@@ -3,7 +3,6 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
-from sglang.srt.utils import wrap_kernel_launcher

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -172,9 +171,6 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])


-cached_kernel = None
-
-
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1

-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q_extend,
-            k_extend,
-            v_extend,
-            o_extend,
-            k_buffer,
-            v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_start_loc_extend,
-            b_seq_len_extend,
-            sm_scale,
-            kv_group_num,
-            q_extend.stride(0),
-            q_extend.stride(1),
-            k_extend.stride(0),
-            k_extend.stride(1),
-            v_extend.stride(0),
-            v_extend.stride(1),
-            o_extend.stride(0),
-            o_extend.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
         num_stages=num_stages,
         logit_cap=logit_cap,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)


 def redundant_attention(
python/sglang/srt/layers/token_attention.py

@@ -6,7 +6,6 @@ import triton
 import triton.language as tl

 from sglang.srt.server import global_server_args_dict
-from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2

-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
     num_warps = 1

-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


 def token_attention_fwd(
python/sglang/srt/server.py

@@ -51,6 +51,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
     }


-def _set_ulimit(target_soft_limit=65535):
-    import resource
-
-    resource_type = resource.RLIMIT_NOFILE
-    current_soft, current_hard = resource.getrlimit(resource_type)
-
-    if current_soft >= target_soft_limit:
-        logger.info(
-            f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
-        )
-    else:
-        try:
-            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
-            new_soft, new_hard = resource.getrlimit(resource_type)
-            logger.info(
-                f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
-            )
-        except ValueError as e:
-            logger.warn(f"Failed to set new limits: {e}")
-            logger.info(
-                f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
-            )
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -186,7 +163,7 @@ def launch_server(
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
-    _set_ulimit()
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
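The server change above also relocates the file-descriptor limit bump: launch_server now imports set_ulimit from sglang.srt.utils instead of defining a private _set_ulimit. Below is a rough standalone sketch of what that helper does, using only the stdlib resource module; the function name raise_nofile_soft_limit is invented here, and unlike the real helper it silently ignores failures instead of logging them.

import resource


def raise_nofile_soft_limit(target_soft_limit: int = 65535) -> None:
    # Raise the soft RLIMIT_NOFILE up to target_soft_limit, leaving the hard
    # limit untouched; ignore the error if the target exceeds the hard limit.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft < target_soft_limit:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (target_soft_limit, hard))
        except ValueError:
            pass


if __name__ == "__main__":
    raise_nofile_soft_limit()
    print(resource.getrlimit(resource.RLIMIT_NOFILE))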
python/sglang/srt/utils.py

@@ -5,6 +5,7 @@ import fcntl
 import logging
 import os
 import random
+import resource
 import socket
 import struct
 import time
@@ -16,6 +17,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed as dist
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
     return logit_bias


-def wrap_kernel_launcher(kernel):
-    """A faster launcher for triton kernels."""
-    if int(triton.__version__.split(".")[0]) >= 3:
-        return None
-
-    gpu_id = torch.cuda.current_device()
-    kernels = kernel.cache[gpu_id].values()
-    kernel = next(iter(kernels))
-
-    # Different trition versions use different low-level names
-    if hasattr(kernel, "cu_function"):
-        kfunction = kernel.cu_function
-    else:
-        kfunction = kernel.function
-
-    if hasattr(kernel, "c_wrapper"):
-        run = kernel.c_wrapper
-    else:
-        run = kernel.run
-
-    add_cluster_dim = True
-
-    def ret_func(grid, num_warps, *args):
-        nonlocal add_cluster_dim
-
-        try:
-            if add_cluster_dim:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    1,
-                    1,
-                    1,
-                    1,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-            else:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-        except TypeError:
-            add_cluster_dim = not add_cluster_dim
-            ret_func(grid, num_warps, *args)
-
-    return ret_func
-
-
 def is_multimodal_model(model):
     from sglang.srt.model_config import ModelConfig
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
 def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
-    import torch.distributed as dist

     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
-    import torch.distributed as dist

     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
     dist.barrier()
     dist.destroy_process_group()
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
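Dropping wrap_kernel_launcher is consistent with the version gate it already contained: the helper returned None on Triton >= 3, presumably because the low-level launcher interface it relied on changed, so on current Triton it was a no-op and every call site already fell back to the standard launch. Below is a tiny sketch of that gate, assuming only that triton exposes a "major.minor.patch" __version__ string; the helper name triton_major_version is invented here.

import triton


def triton_major_version() -> int:
    # e.g. "2.3.1" -> 2, "3.0.0" -> 3
    return int(triton.__version__.split(".")[0])


if __name__ == "__main__":
    if triton_major_version() >= 3:
        print("Triton >= 3: the removed launcher cache would have been disabled anyway.")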