Unverified commit 99ec439d, authored Sep 30, 2024 by Liangsheng Yin, committed by GitHub Sep 30, 2024

Organize Attention Backends (#1547)

parent 0f4fb19b
Showing 12 changed files with 229 additions and 205 deletions (+229 / −205)
python/sglang/srt/layers/attention/__init__.py                      +49  −0
python/sglang/srt/layers/attention/flashinfer_backend.py            +2   −195
python/sglang/srt/layers/attention/flashinfer_utils.py              +0   −0
python/sglang/srt/layers/attention/triton_backend.py                +161 −0
python/sglang/srt/layers/attention/triton_ops/decode_attention.py   +0   −0
python/sglang/srt/layers/attention/triton_ops/extend_attention.py   +3   −1
python/sglang/srt/layers/attention/triton_ops/prefill_attention.py  +0   −0
python/sglang/srt/model_executor/forward_batch_info.py              +1   −1
python/sglang/srt/model_executor/model_runner.py                    +2   −1
scripts/deprecated/test_flashinfer.py                               +3   −3
test/srt/test_create_kvindices.py                                   +3   −1
test/srt/test_triton_attention_kernels.py                           +5   −3
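
Read together, the changed paths amount to the following layout of the new attention package (a summary of the file list above; the comments are editorial, not part of the commit):

python/sglang/srt/layers/attention/
    __init__.py              # new: AttentionBackend base class
    flashinfer_backend.py    # renamed from layers/attention_backend.py, FlashInfer kernels only
    flashinfer_utils.py      # moved from layers/flashinfer_utils.py
    triton_backend.py        # new: TritonAttnBackend, split out of the old attention_backend.py
    triton_ops/
        decode_attention.py      # moved from layers/triton_attention/decode_attention.py
        extend_attention.py      # moved from layers/triton_attention/extend_attention.py
        prefill_attention.py     # moved from layers/triton_attention/prefill_attention.py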
python/sglang/srt/layers/attention/__init__.py (new file, mode 100644)

from abc import ABC, abstractmethod

from torch import nn

from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        raise NotImplementedError()

    def init_cuda_graph_state(self, max_bs: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        """Init the metadata for a forward pass for capturing a cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        else:
            return self.forward_extend(q, k, v, layer, forward_batch)

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        """Run a forward for extend."""
        raise NotImplementedError()
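
As a usage illustration only, here is a minimal sketch of a concrete backend built on this interface. The class name NaiveSDPABackend, the use of torch.nn.functional.scaled_dot_product_attention, and the single-sequence, equal-head-count, no-prefix-cache assumptions are all hypothetical and not part of this commit; a real backend, like the two below, reads and writes the KV cache pools on forward_batch.

import torch
from torch import nn
from torch.nn import functional as F

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class NaiveSDPABackend(AttentionBackend):
    """Hypothetical toy backend: plain SDPA over the tokens of one request."""

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # Nothing to precompute for this toy backend.
        pass

    def _attend(self, q, k, v, layer: nn.Module):
        # Assumes q/k/v are packed as (num_tokens, num_heads * head_dim)
        # and that q and kv use the same number of heads (no GQA).
        q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim).transpose(0, 1)
        k = k.view(-1, layer.tp_q_head_num, layer.qk_head_dim).transpose(0, 1)
        v = v.view(-1, layer.tp_q_head_num, layer.v_head_dim).transpose(0, 1)
        o = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return o.transpose(0, 1).reshape(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        return self._attend(q, k, v, layer)

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # A real backend would also attend over the cached prefix here.
        return self._attend(q, k, v, layer)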
python/sglang/srt/layers/attention_backend.py → python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,15 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

-from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn

 from sglang.global_config import global_config
-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip

@@ -33,50 +32,6 @@ if not is_hip():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

    [44 removed lines elided: the AttentionBackend base class, now defined in
     python/sglang/srt/layers/attention/__init__.py, shown above]

 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

@@ -329,151 +284,3 @@ class FlashInferAttnBackend(AttentionBackend):
         )
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    [148 removed lines elided: the TritonAttnBackend class, moved (with updated
     triton_ops import paths) to python/sglang/srt/layers/attention/triton_backend.py,
     shown in full below]
python/sglang/srt/layers/flashinfer_utils.py → python/sglang/srt/layers/attention/flashinfer_utils.py

File moved
python/sglang/srt/layers/attention/triton_backend.py (new file, mode 100644)

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        if forward_batch.forward_mode.is_decode():
            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(forward_batch.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = attn_logits = max_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v
        )

        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            start_loc,
            forward_batch.seq_lens,
            attn_logits,
            max_seq_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o
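
In the decode branch of init_forward_metadata above, start_loc is an exclusive prefix sum of seq_lens, which gives each request a contiguous slice of the (num_head, total_num_tokens) attn_logits scratch buffer. A small standalone check of that arithmetic (plain PyTorch on CPU, with hand-picked values; not part of this commit):

import torch

# Hand-picked batch of three requests with 3, 5, and 2 cached tokens.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Same recipe as TritonAttnBackend.init_forward_metadata (decode branch).
start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

total_num_tokens = torch.sum(seq_lens).item()
max_seq_len = torch.max(seq_lens).item()

print(start_loc.tolist())   # [0, 3, 8]: exclusive prefix sum of seq_lens
print(total_num_tokens)     # 10: second dimension of the attn_logits buffer
print(max_seq_len)          # 5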
python/sglang/srt/layers/triton_attention/decode_attention.py → python/sglang/srt/layers/attention/triton_ops/decode_attention.py

File moved
python/sglang/srt/layers/triton_attention/extend_attention.py → python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -22,7 +22,9 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
python/sglang/srt/layers/triton_attention/prefill_attention.py → python/sglang/srt/layers/attention/triton_ops/prefill_attention.py

File moved
python/sglang/srt/model_executor/forward_batch_info.py

@@ -37,7 +37,7 @@ import numpy as np
 import torch

 if TYPE_CHECKING:
-    from sglang.srt.layers.attention_backend import AttentionBackend
+    from sglang.srt.layers.attention import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
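
The updated import stays inside the if TYPE_CHECKING: block because it is only needed for annotations: the attention package itself imports ForwardBatch from this module (see __init__.py above), so importing it here at runtime would create a cycle. A generic sketch of the pattern (the ForwardBatchSketch name and field are hypothetical, not taken from this file):

from __future__ import annotations  # keep annotations unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so no circular import.
    from sglang.srt.layers.attention import AttentionBackend


class ForwardBatchSketch:
    # The annotation resolves lazily; the real ForwardBatch presumably
    # annotates its backend field in the same way.
    attn_backend: AttentionBackend | None = None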
python/sglang/srt/model_executor/model_runner.py

@@ -39,7 +39,8 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.constrained import disable_cache
-from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
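
After the split, each backend is imported from its own module (the two + lines above). The selection logic is outside this hunk; the factory below is a hypothetical sketch assuming both constructors take a ModelRunner, as TritonAttnBackend's does:

from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend


def create_attention_backend(name: str, model_runner):
    # Hypothetical helper; the real ModelRunner wires its backend up internally.
    if name == "flashinfer":
        return FlashInferAttnBackend(model_runner)
    if name == "triton":
        return TritonAttnBackend(model_runner)
    raise ValueError(f"unknown attention backend: {name!r}")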
scripts/deprecated/test_flashinfer.py

@@ -6,8 +6,8 @@ from flashinfer import (
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.layers.triton_attention.extend_attention import (
+from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )

@@ -159,7 +159,7 @@ def test_batch_decode_with_paged_kv_cache(
     b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
     max_len_in_batch = kv_len
     other_kv_index = 0

-    token_attention_fwd(
+    decode_attention_fwd(
         q,
         k_buffer,
         v_buffer,
test/srt/test_create_kvindices.py

@@ -4,7 +4,9 @@ import unittest
 import numpy as np
 import torch

-from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.flashinfer_utils import (
+    create_flashinfer_kv_indices_triton,
+)


 class TestCreateKvIndices(unittest.TestCase):
test/srt/test_triton_attention_kernels.py

@@ -3,12 +3,14 @@ import unittest
 import torch

-from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
-from sglang.srt.layers.triton_attention.extend_attention import (
+from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
-from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)


 class TestExtendAttention(unittest.TestCase):