change / sglang / Commits / 2d611323

Support Eagle2 for Triton backend (#3466)

Unverified commit 2d611323, authored Feb 10, 2025 by Ke Bao, committed by GitHub on Feb 10, 2025.
Parent: cddb1cdf

Showing 5 changed files with 286 additions and 42 deletions (+286 -42):

  python/sglang/srt/layers/attention/triton_backend.py               +223  -24
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py    +4   -4
  python/sglang/srt/speculative/eagle_worker.py                       +24   -8
  test/srt/test_eagle_infer.py                                        +29   -0
  test/srt/test_triton_attention_kernels.py                            +6   -6
python/sglang/srt/layers/attention/triton_backend.py (view file @ 2d611323)

@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Optional

 import torch
+import triton

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@ if TYPE_CHECKING:

 class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
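The TritonMultiStepDraftBackend added at the end of this file builds one TritonAttnBackend per draft step with skip_prefill=True and hands each instance a row of a shared buffer via kv_indptr_buf. A minimal sketch of that buffer-sharing pattern (not part of the commit; Backend is a stand-in for TritonAttnBackend, reduced to the constructor logic above):

# Sketch only: per-step backends sharing one preallocated kv_indptr buffer.
from typing import Optional

import torch


class Backend:
    def __init__(self, max_bs: int, kv_indptr_buf: Optional[torch.Tensor] = None):
        if kv_indptr_buf is None:
            # Standalone use: the backend owns its own buffer.
            self.kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)
        else:
            # Multi-step draft use: reuse a row of a caller-owned buffer.
            self.kv_indptr = kv_indptr_buf


speculative_num_steps, max_bs = 5, 8
shared = torch.zeros((speculative_num_steps, max_bs + 1), dtype=torch.int32)
backends = [
    Backend(max_bs, kv_indptr_buf=shared[i]) for i in range(speculative_num_steps)
]
# Each backend holds a view into the shared buffer, so no extra allocation or copy.
assert backends[0].kv_indptr.data_ptr() == shared[0].data_ptr()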
@@ -33,14 +39,25 @@ class TritonAttnBackend(AttentionBackend):
         self.extend_attention_fwd = extend_attention_fwd

         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.qo_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
+        self.mask_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+        )
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -50,7 +67,7 @@ class TritonAttnBackend(AttentionBackend):
         self.forward_metadata = None

-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+        self.max_context_len = model_runner.model_config.context_len

         self.device = model_runner.device
@@ -59,11 +76,31 @@ class TritonAttnBackend(AttentionBackend):
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info

-        if forward_batch.forward_mode.is_decode():
-            attn_logits = torch.empty(
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            attn_logits = torch.zeros(
                 (
-                    forward_batch.batch_size,
+                    bs,
                     self.num_head,
                     self.num_kv_splits,
                     self.v_head_dim + 1,
@@ -72,12 +109,24 @@ class TritonAttnBackend(AttentionBackend):
                 dtype=torch.float32,
                 device=self.device,
             )

+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
             max_extend_len = None
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -89,15 +138,32 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

-            qo_indptr = None
-            custom_mask = None
-            mask_offsets = None
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
+            kv_indices = torch.zeros(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -116,8 +182,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
-            mask_offsets = None
+            mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
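For orientation, a small sketch (not part of the commit) of the index layouts that init_forward_metadata builds above, evaluated on toy numbers:

# Sketch only: kv_indptr / qo_indptr / mask_indptr layout on a toy batch.
import torch

seq_lens = torch.tensor([3, 5])   # two requests with KV lengths 3 and 5
num_draft_tokens = 4
bs = seq_lens.numel()

# kv_indptr: exclusive prefix sum over per-request KV lengths, so request i owns
# kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
print(kv_indptr.tolist())   # [0, 3, 8]

# qo_indptr in the target-verify branch: every request contributes exactly
# num_draft_tokens query tokens, so the offsets form a fixed-stride arange.
qo_indptr = torch.arange(
    0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
)
print(qo_indptr.tolist())   # [0, 4, 8]

# mask_indptr: each request's flattened custom mask covers
# num_draft_tokens * (seq_len + num_draft_tokens) boolean entries.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
mask_indptr = torch.zeros(bs + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)
print(mask_indptr.tolist())  # [0, 28, 64]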
@@ -128,22 +193,22 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         )

     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len

         self.cuda_graph_start_loc = torch.zeros(
             (max_bs,), dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_attn_logits = torch.empty(
+        self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.cuda_graph_max_seq_len),
+            (max_bs * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
@@ -244,8 +309,9 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         ) = self.forward_metadata
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -257,7 +323,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -303,3 +369,136 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
         )
         return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
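A small sketch (not part of the commit) of the per-step slicing that common_template applies to the flat kv_indices buffer above: at draft step i, every one of the topk branches of every request attends to its original prefix plus the i + 1 tokens drafted so far, so the slice length is seq_lens_sum * topk + (batch_size * topk) * (i + 1).

# Sketch only: per-step kv_indices slice lengths on toy numbers.
seq_lens = [3, 5]                  # toy prefix lengths for two requests
topk = 2
speculative_num_steps = 3

num_seqs = len(seq_lens)
seq_lens_sum = sum(seq_lens)       # 8
bs = topk * num_seqs               # 4 draft branches in total

for i in range(speculative_num_steps):
    # Matches: kv_indices_buffer[i][: seq_lens_sum * topk + bs * (i + 1)]
    print(i, seq_lens_sum * topk + bs * (i + 1))
# 0 20
# 1 24
# 2 28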
python/sglang/srt/layers/attention/triton_ops/extend_attention.py (view file @ 2d611323)

@@ -50,7 +50,7 @@ def _fwd_kernel(
     kv_indptr,
     kv_indices,
     mask_ptr,
-    mask_offsets,
+    mask_indptr,
     sm_scale,
     kv_group_num,
     stride_qbs,

@@ -87,7 +87,7 @@ def _fwd_kernel(
     cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend

     if USE_CUSTOM_MASK:
-        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)

@@ -288,7 +288,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
-    mask_offsets,
+    mask_indptr,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,

@@ -364,7 +364,7 @@ def extend_attention_fwd(
         kv_indptr,
         kv_indices,
         custom_mask,
-        mask_offsets,
+        mask_indptr,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
python/sglang/srt/speculative/eagle_worker.py (view file @ 2d611323)

@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

         # Create multi-step attn backends and cuda graph runners
-        from sglang.srt.layers.attention.flashinfer_backend import (
-            FlashInferMultiStepDraftBackend,
-        )
+        if server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )

-        self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-            self.model_runner,
-            self.topk,
-            self.speculative_num_steps,
-        )
+            self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+        elif server_args.attention_backend == "triton":
+            from sglang.srt.layers.attention.triton_backend import (
+                TritonMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TritonMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+        else:
+            raise ValueError(
+                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+            )
+
         self.model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()
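With this dispatch in place, the Triton path is selected purely via --attention-backend. The new test in test/srt/test_eagle_infer.py (next file) exercises it with the flags below; for reference, a hedged standalone launch sketch with placeholder model paths instead of the test's DEFAULT_* constants:

# Sketch only: launching a server that takes the new "triton" branch above.
# The flags mirror those passed by the new TestEAGLEServerTriton test; the
# model paths are placeholders, not values from the commit.
import subprocess

cmd = [
    "python3", "-m", "sglang.launch_server",
    "--model-path", "/path/to/target-model",                 # placeholder
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft-model-path", "/path/to/eagle-draft-model",  # placeholder
    "--speculative-num-steps", "5",
    "--speculative-eagle-topk", "8",
    "--speculative-num-draft-tokens", "64",
    "--mem-fraction-static", "0.7",
    "--attention-backend", "triton",
    "--disable-cuda-graph",   # the test notes cuda graph is not supported yet
]
subprocess.run(cmd, check=True)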
test/srt/test_eagle_infer.py (view file @ 2d611323)

@@ -193,5 +193,34 @@ class TestEAGLEServer(unittest.TestCase):
         self.assertGreater(metrics["accuracy"], 0.20)


+class TestEAGLEServerTriton(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                "5",
+                "--speculative-eagle-topk",
+                "8",
+                "--speculative-num-draft-tokens",
+                "64",
+                "--mem-fraction-static",
+                "0.7",
+                "--attention-backend",
+                "triton",
+                # TODO: Support cuda graph
+                "--disable-cuda-graph",
+            ],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
test/srt/test_triton_attention_kernels.py (view file @ 2d611323)

@@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase):
         qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

         custom_mask = None
-        mask_offsets = None
+        mask_indptr = None

         extend_attention_fwd(
             q_extend,

@@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_len_extend,
         )

@@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase):
         custom_mask = torch.ones(
             (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
         )
-        mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
-        mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
+        mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+        mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
         for i in range(B):
             causal_mask = (
                 torch.tril(

@@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase):
                 b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
             )
             mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
-            custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
+            custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten

         extend_attention_fwd(
             q_extend,

@@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_len_extend,
         )
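A small sketch (not part of the commit) of the flattened custom-mask layout that the test builds and that the kernel addresses through mask_indptr: for one sequence, the mask is an (extend x (prefix + extend)) boolean matrix, fully visible over the prefix columns and causal over the extend columns, stored row-major in the flat custom_mask buffer. The prefix part is constructed in lines not shown in the hunks above; an all-True prefix mask is assumed here.

# Sketch only: one sequence's flattened mask, with an assumed all-True prefix block.
import torch

prefix_len, extend_len = 3, 2
prefix_mask = torch.ones((extend_len, prefix_len), dtype=torch.bool)              # assumed
causal_mask = torch.tril(torch.ones((extend_len, extend_len), dtype=torch.bool))  # causal part
mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
print(mask_flatten.int().tolist())
# [1, 1, 1, 1, 0, 1, 1, 1, 1, 1]
# Row 0: extend token 0 sees the prefix plus itself; row 1: token 1 sees everything.
# The flattened length, extend_len * (prefix_len + extend_len), matches the
# per-sequence span recorded in mask_indptr.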