sglang commit 2d611323 (Unverified), parent cddb1cdf
Authored by Ke Bao on Feb 10, 2025; committed via GitHub on Feb 10, 2025

    Support Eagle2 for Triton backend (#3466)
Showing 5 changed files with 286 additions and 42 deletions (+286 / -42):

  python/sglang/srt/layers/attention/triton_backend.py (+223 / -24)
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py (+4 / -4)
  python/sglang/srt/speculative/eagle_worker.py (+24 / -8)
  test/srt/test_eagle_infer.py (+29 / -0)
  test/srt/test_triton_attention_kernels.py (+6 / -6)
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Optional
 
 import torch
+import triton
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@ if TYPE_CHECKING:
 class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
@@ -33,14 +39,25 @@ class TritonAttnBackend(AttentionBackend):
         self.extend_attention_fwd = extend_attention_fwd
 
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.qo_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
+        self.mask_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+        )
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -50,7 +67,7 @@ class TritonAttnBackend(AttentionBackend):
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+        self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
@@ -59,11 +76,31 @@ class TritonAttnBackend(AttentionBackend):
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
 
-        if forward_batch.forward_mode.is_decode():
-            attn_logits = torch.empty(
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            attn_logits = torch.zeros(
                 (
-                    forward_batch.batch_size,
+                    bs,
                     self.num_head,
                     self.num_kv_splits,
                     self.v_head_dim + 1,
@@ -72,12 +109,24 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
 
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
             max_extend_len = None
-
-            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
-            kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
-            )
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different with flashinfer kv_indptr and kv_indices construction
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
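
Note on the index layout built above: for regular decode (no spec_info), the backend prepares the paged KV-cache lookup in a CSR-like form, where kv_indptr is the running sum of sequence lengths and kv_indices is the flat list of KV-cache slots gathered per request from the request-to-token table; when speculative info is present, the pre-built indptr/indices from spec_info are used instead. A minimal PyTorch sketch of that layout follows (toy tensors stand in for the real pools, and a plain loop stands in for the create_flashinfer_kv_indices_triton kernel):

    import torch

    # Toy request-to-token table: row r lists the KV-cache slots owned by request r.
    req_to_token = torch.tensor(
        [
            [10, 11, 12, 0],   # request 0 currently uses slots 10..12
            [20, 21, 22, 23],  # request 1 currently uses slots 20..23
        ],
        dtype=torch.int32,
    )
    seq_lens = torch.tensor([3, 4], dtype=torch.int32)
    bs = seq_lens.numel()

    # CSR-style offsets: entries kv_indptr[i] .. kv_indptr[i+1] belong to request i.
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

    # Flat slot list; the real backend fills this with a Triton kernel.
    kv_indices = torch.zeros(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[i, : seq_lens[i]]

    print(kv_indptr.tolist())   # [0, 3, 7]
    print(kv_indices.tolist())  # [10, 11, 12, 20, 21, 22, 23]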
@@ -89,15 +138,32 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
 
-            qo_indptr = None
-            custom_mask = None
-            mask_offsets = None
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
+            kv_indices = torch.zeros(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
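
The target-verify branch above consumes the draft tree's custom attention mask as one flat boolean buffer: each request contributes num_draft_tokens * (seq_len + num_draft_tokens) mask bits, and mask_indptr (the cumulative sum of those block sizes) tells the kernel where each request's block starts, the same layout the updated unit test builds further down. A small self-contained sketch of that indexing with made-up sizes; the dense-prefix/lower-triangular fill is only illustrative (in EAGLE the draft part encodes the token tree):

    import torch

    num_draft_tokens = 2
    seq_lens = torch.tensor([3, 5])  # toy per-request KV lengths

    # Per-request flattened mask sizes and their CSR-style offsets.
    seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
    mask_indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int64)
    mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)

    custom_mask = torch.zeros(int(mask_indptr[-1]), dtype=torch.bool)
    for i, s in enumerate(seq_lens.tolist()):
        # Draft tokens may attend to the whole prefix (illustrative) ...
        sub = torch.ones(num_draft_tokens, s + num_draft_tokens, dtype=torch.bool)
        # ... and to earlier draft tokens only (illustrative causal fill).
        sub[:, s:] = torch.tril(
            torch.ones(num_draft_tokens, num_draft_tokens, dtype=torch.bool)
        )
        custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = sub.flatten()

    print(mask_indptr.tolist())  # [0, 10, 24]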
@@ -116,8 +182,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
-            mask_offsets = None
+            mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
@@ -128,22 +193,22 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         )
 
     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
 
         self.cuda_graph_start_loc = torch.zeros(
             (max_bs,), dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_attn_logits = torch.empty(
+        self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.cuda_graph_max_seq_len),
+            (max_bs * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
@@ -244,8 +309,9 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         ) = self.forward_metadata
 
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -257,7 +323,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -303,3 +369,136 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
         )
         return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
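
In short, the class added above keeps one TritonAttnBackend per draft step (all sharing the preallocated kv_indptr buffer and skipping prefill), fills the per-step KV indices with a single generate_draft_decode_kv_indices launch, and then replays call_fn once per step with that step's slice. A rough usage sketch, mirroring the eagle_worker.py change below; model_runner and forward_batch are placeholders for the draft model's runner and its current batch:

    from sglang.srt.layers.attention.triton_backend import TritonMultiStepDraftBackend

    # One backend object drives all speculative_num_steps draft decode steps.
    draft_attn_backend = TritonMultiStepDraftBackend(
        model_runner,             # draft ModelRunner (placeholder)
        topk=8,                   # --speculative-eagle-topk
        speculative_num_steps=5,  # --speculative-num-steps
    )
    model_runner.draft_attn_backend = draft_attn_backend

    # Before a multi-step draft decode, per-step metadata is prepared in one call.
    draft_attn_backend.init_forward_metadata(forward_batch)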
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -50,7 +50,7 @@ def _fwd_kernel(
     kv_indptr,
     kv_indices,
     mask_ptr,
-    mask_offsets,
+    mask_indptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -87,7 +87,7 @@ def _fwd_kernel(
     cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
     if USE_CUSTOM_MASK:
-        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -288,7 +288,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
-    mask_offsets,
+    mask_indptr,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -364,7 +364,7 @@ def extend_attention_fwd(
         kv_indptr,
         kv_indices,
         custom_mask,
-        mask_offsets,
+        mask_indptr,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
 
         # Create multi-step attn backends and cuda graph runners
-        from sglang.srt.layers.attention.flashinfer_backend import (
-            FlashInferMultiStepDraftBackend,
-        )
-
-        self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-            self.model_runner,
-            self.topk,
-            self.speculative_num_steps,
-        )
+        if server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+        elif server_args.attention_backend == "triton":
+            from sglang.srt.layers.attention.triton_backend import (
+                TritonMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TritonMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+        else:
+            raise ValueError(
+                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+            )
+
         self.model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()
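
With this change an EAGLE draft model can be served on the Triton attention backend. A hedged launch sketch using the same flags the new test below passes to the server (the model paths are placeholders, and cuda graph is disabled per the TODO in the test):

    import subprocess

    # Sketch only: builds the same server arguments the test passes to popen_launch_server.
    cmd = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "<target-model>",                          # placeholder
        "--speculative-algorithm", "EAGLE",
        "--speculative-draft-model-path", "<eagle-draft-model>",   # placeholder
        "--speculative-num-steps", "5",
        "--speculative-eagle-topk", "8",
        "--speculative-num-draft-tokens", "64",
        "--mem-fraction-static", "0.7",
        "--attention-backend", "triton",
        "--disable-cuda-graph",
    ]
    server = subprocess.Popen(cmd)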
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -193,5 +193,34 @@ class TestEAGLEServer(unittest.TestCase):
         self.assertGreater(metrics["accuracy"], 0.20)
 
 
+class TestEAGLEServerTriton(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                "5",
+                "--speculative-eagle-topk",
+                "8",
+                "--speculative-num-draft-tokens",
+                "64",
+                "--mem-fraction-static",
+                "0.7",
+                "--attention-backend",
+                "triton",
+                # TODO: Support cuda graph
+                "--disable-cuda-graph",
+            ],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
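
A minimal way to exercise just the new Triton EAGLE server test, assuming it is run from the repository's test/srt directory (the path layout is an assumption):

    import unittest

    from test_eagle_infer import TestEAGLEServerTriton

    if __name__ == "__main__":
        suite = unittest.TestLoader().loadTestsFromTestCase(TestEAGLEServerTriton)
        unittest.TextTestRunner(verbosity=2).run(suite)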
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase):
         qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
 
         custom_mask = None
-        mask_offsets = None
+        mask_indptr = None
 
         extend_attention_fwd(
             q_extend,
@@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_len_extend,
         )
@@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase):
         custom_mask = torch.ones(
             (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
         )
-        mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
-        mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
+        mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+        mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
         for i in range(B):
             causal_mask = (
                 torch.tril(
@@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase):
                 b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
             )
             mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
-            custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
+            custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten
 
         extend_attention_fwd(
             q_extend,
@@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_len_extend,
         )