sglang · Commit 99ec439d (unverified)

Organize Attention Backends (#1547)

Authored Sep 30, 2024 by Liangsheng Yin; committed by GitHub on Sep 30, 2024.
Parent commit: 0f4fb19b

Showing 12 changed files with 229 additions and 205 deletions.
Files changed (12):

  python/sglang/srt/layers/attention/__init__.py                        +49    -0
  python/sglang/srt/layers/attention/flashinfer_backend.py               +2  -195
  python/sglang/srt/layers/attention/flashinfer_utils.py                 +0    -0
  python/sglang/srt/layers/attention/triton_backend.py                 +161    -0
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py      +0    -0
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py      +3    -1
  python/sglang/srt/layers/attention/triton_ops/prefill_attention.py     +0    -0
  python/sglang/srt/model_executor/forward_batch_info.py                 +1    -1
  python/sglang/srt/model_executor/model_runner.py                       +2    -1
  scripts/deprecated/test_flashinfer.py                                  +3    -3
  test/srt/test_create_kvindices.py                                      +3    -1
  test/srt/test_triton_attention_kernels.py                              +5    -3
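For downstream code, the reorganization boils down to the import-path changes below. Every new path is taken directly from the hunks that follow; grouping them into one block is only a convenience summary, not part of the commit itself.

# Old locations (removed or renamed by this commit):
#   sglang.srt.layers.attention_backend   -> split into the two backend modules below
#   sglang.srt.layers.flashinfer_utils    -> sglang.srt.layers.attention.flashinfer_utils
#   sglang.srt.layers.triton_attention.*  -> sglang.srt.layers.attention.triton_ops.*

# New locations:
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd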
python/sglang/srt/layers/attention/__init__.py (new file, mode 100644)
from abc import ABC, abstractmethod

from torch import nn

from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        raise NotImplementedError()

    def init_cuda_graph_state(self, max_bs: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        """Init the metadata for a forward pass for capturing a cuda graph."""
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch)
        else:
            return self.forward_extend(q, k, v, layer, forward_batch)

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        """Run a forward for extend."""
        raise NotImplementedError()
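To make the interface concrete, here is a minimal sketch of a subclass. DummyAttnBackend and its no-op bodies are illustrative only and not part of this commit; they just show which hooks a real backend such as FlashInferAttnBackend or TritonAttnBackend has to fill in.

# Illustrative only -- not part of this commit.
import torch
from torch import nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class DummyAttnBackend(AttentionBackend):
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # A real backend precomputes kernel metadata (indices, wrappers, buffers) here.
        pass

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # Placeholder: a real backend writes k/v into the KV cache and runs a decode kernel.
        return torch.empty_like(q)

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # Placeholder: a real backend runs an extend (prefill with cached prefix) kernel.
        return torch.empty_like(q)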
python/sglang/srt/layers/attention_backend.py → python/sglang/srt/layers/attention/flashinfer_backend.py (renamed)
@@ -7,15 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

-from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn

 from sglang.global_config import global_config
-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip
@@ -33,50 +32,6 @@ if not is_hip():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-class AttentionBackend(ABC):
-    """The base class of attention backends"""

[... the remainder of the 44 removed lines is the AttentionBackend base class,
verbatim as it now appears in python/sglang/srt/layers/attention/__init__.py above ...]

 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -329,151 +284,3 @@ class FlashInferAttnBackend(AttentionBackend):
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):

[... the remainder of the removed block (about 148 lines) is the TritonAttnBackend
class; it moves to the new file python/sglang/srt/layers/attention/triton_backend.py
shown below, with its lazy imports updated from sglang.srt.layers.triton_attention.*
to sglang.srt.layers.attention.triton_ops.* ...]
python/sglang/srt/layers/flashinfer_utils.py → python/sglang/srt/layers/attention/flashinfer_utils.py
File moved.
python/sglang/srt/layers/attention/triton_backend.py (new file, mode 100644)
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        if forward_batch.forward_mode.is_decode():
            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(forward_batch.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = attn_logits = max_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v
        )

        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            start_loc,
            forward_batch.seq_lens,
            attn_logits,
            max_seq_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o
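As a quick sanity check on the decode branch of init_forward_metadata above, here is a tiny CPU-only example. The sequence lengths and head count are made up for illustration; the real code allocates attn_logits on "cuda".

import torch

# Three decode requests with 3, 5, and 2 cached tokens respectively (made-up values).
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Exclusive prefix sum: each request's offset into the flattened KV tokens.
start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)    # start_loc == [0, 3, 8]

total_num_tokens = torch.sum(seq_lens).item()          # 10
max_seq_len = torch.max(seq_lens).item()               # 5

# attn_logits is then a (num_head, total_num_tokens) scratch buffer in
# self.reduce_dtype that decode_attention_fwd uses for intermediate logits.
num_head = 4                                           # assumed head count for the sketch
attn_logits = torch.empty((num_head, total_num_tokens), dtype=torch.float16)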
python/sglang/srt/layers/triton_attention/decode_attention.py → python/sglang/srt/layers/attention/triton_ops/decode_attention.py
File moved.
python/sglang/srt/layers/triton_attention/extend_attention.py → python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -22,7 +22,9 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
python/sglang/srt/layers/triton_attention/prefill_attention.py → python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
File moved.
python/sglang/srt/model_executor/forward_batch_info.py
@@ -37,7 +37,7 @@ import numpy as np
 import torch

 if TYPE_CHECKING:
-    from sglang.srt.layers.attention_backend import AttentionBackend
+    from sglang.srt.layers.attention import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
python/sglang/srt/model_executor/model_runner.py
@@ -39,7 +39,8 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.constrained import disable_cache
-from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
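The hunk above only shows the import split; the code in ModelRunner that actually picks a backend lies outside the shown context. Purely as an illustration of how the two classes are consumed, a sketch might look like the following, where the helper function and the server_args.attention_backend attribute name are assumptions, not taken from this commit.

# Hypothetical sketch -- not code from this commit.
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend


def create_attn_backend(model_runner):
    """Pick one concrete AttentionBackend implementation for this model runner."""
    backend = getattr(model_runner.server_args, "attention_backend", "flashinfer")
    if backend == "triton":
        return TritonAttnBackend(model_runner)
    return FlashInferAttnBackend(model_runner)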
scripts/deprecated/test_flashinfer.py
@@ -6,8 +6,8 @@ from flashinfer import (
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.layers.triton_attention.extend_attention import (
+from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )

@@ -159,7 +159,7 @@ def test_batch_decode_with_paged_kv_cache(
     b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
     max_len_in_batch = kv_len
     other_kv_index = 0

-    token_attention_fwd(
+    decode_attention_fwd(
         q,
         k_buffer,
         v_buffer,
test/srt/test_create_kvindices.py
@@ -4,7 +4,9 @@ import unittest
 import numpy as np
 import torch

-from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.flashinfer_utils import (
+    create_flashinfer_kv_indices_triton,
+)


 class TestCreateKvIndices(unittest.TestCase):
test/srt/test_triton_attention_kernels.py
@@ -3,12 +3,14 @@ import unittest
 import torch

-from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
-from sglang.srt.layers.triton_attention.extend_attention import (
+from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
-from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)


 class TestExtendAttention(unittest.TestCase):