sglang · Commit 60fdad7c (unverified)

Sync the changes on cuda graph runners (#6932)

Authored Jun 06, 2025 by Lianmin Zheng; committed via GitHub on Jun 06, 2025.
Parent commit: 61ce91ed

Showing 7 changed files with 63 additions and 70 deletions.
python/sglang/srt/managers/io_struct.py                                  +2   -2
python/sglang/srt/model_executor/cuda_graph_runner.py                   +21  -30
python/sglang/srt/server_args.py                                        +13   -2
python/sglang/srt/speculative/build_eagle_tree.py                        +8   -8
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py           +3   -8
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py    +3   -9
python/sglang/srt/speculative/eagle_utils.py                            +13  -11
python/sglang/srt/managers/io_struct.py

@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from sglang.srt.mm_utils import has_valid_data
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
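The only changes here drop two now-unused imports. For orientation, the second hunk sits inside the module's TYPE_CHECKING guard, the standard Python idiom this file uses to type against PIL's Image without importing it at runtime. A simplified sketch of the idiom (not the file's exact code):

    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        # Seen only by type checkers; no hard PIL dependency at runtime.
        from PIL.Image import Image
    else:
        Image = Any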
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -259,23 +259,8 @@ class CudaGraphRunner:
         }

         # Speculative_inference
-        if (
-            model_runner.spec_algorithm.is_eagle3()
-            and not model_runner.is_draft_worker
-        ):
-            self.hidden_states = torch.zeros(
-                (
-                    self.max_num_token,
-                    3 * self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
+        if model_runner.spec_algorithm.is_eagle3():
             self.model_runner.model.set_eagle3_layers_to_capture()
-        elif model_runner.spec_algorithm.is_eagle():
-            self.hidden_states = torch.zeros(
-                (self.max_num_token, self.model_runner.model_config.hidden_size),
-                dtype=self.model_runner.dtype,
-            )

         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -284,6 +269,7 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None

         if self.enable_dp_attention or self.enable_sp_layernorm:
             # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
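Both hunks above concern the static buffers a CudaGraphRunner allocates before capture: the EAGLE hidden-state buffers move out of the base runner, while encoder_lens and gathered_buffer stay. For context, the PyTorch pattern underneath all of these runners is to allocate fixed tensors once, capture the graph against their storage, then copy fresh inputs in place before each replay. A minimal generic sketch using only public torch.cuda APIs (not sglang's wrapper code):

    import torch

    linear = torch.nn.Linear(16, 16).cuda()
    static_in = torch.zeros(8, 16, device="cuda")

    # Warm-up pass on a side stream before capture, as the PyTorch docs advise.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        linear(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = linear(static_in)  # kernels are recorded against static_in

    static_in.copy_(torch.randn(8, 16, device="cuda"))  # refresh inputs in place
    g.replay()  # rerun the recorded kernels; static_out now holds new results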
@@ -303,13 +289,7 @@ class CudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     @contextmanager
@@ -439,6 +419,7 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
         if self.model_runner.server_args.lora_paths is not None:
             # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
             # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
@@ -467,9 +448,9 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
-            lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
+            lora_paths=lora_paths,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
@@ -497,7 +478,9 @@ class CudaGraphRunner:
             self.pp_size > 1
             and "pp_proxy_tensors" in inspect.signature(forward).parameters
         ):
-            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+            kwargs["pp_proxy_tensors"] = PPProxyTensors(
+                {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
+            )

         logits_output_or_pp_proxy_tensors = forward(
             input_ids,
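The pipeline-parallel fix in the last hunk stops passing the caller's pp_proxy_tensors object straight into the captured forward call and instead clones each tensor into a fresh PPProxyTensors. The reason is aliasing: a replayed graph keeps reading whatever storage it was captured against. A minimal sketch, assuming PPProxyTensors is essentially a named wrapper around a dict of tensors (which is what its construction in the diff suggests):

    import torch

    class PPProxyTensors:  # simplified stand-in for sglang's class
        def __init__(self, tensors: dict):
            self.tensors = tensors

    proxy = PPProxyTensors({"residual": torch.zeros(4)})

    # Cloning snapshots the values; without it, any later in-place update by
    # the caller would silently change what the capture-time kwargs point at.
    snapshot = PPProxyTensors({k: v.clone() for k, v in proxy.tensors.items()})
    proxy.tensors["residual"].fill_(1.0)
    print(snapshot.tensors["residual"])  # tensor([0., 0., 0., 0.])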
@@ -590,9 +573,6 @@ class CudaGraphRunner:
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

-        if hasattr(forward_batch.spec_info, "hidden_states"):
-            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
-
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -650,7 +630,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.zeros(
+                custom_mask=torch.ones(
                     (num_tokens * self.model_runner.model_config.context_len),
                     dtype=torch.bool,
                     device="cuda",
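The zeros-to-ones switch matters because custom_mask is a flattened boolean attention mask used while capturing a dummy verify pass. In PyTorch's attention convention a boolean True means "may attend", so an all-zeros mask blocks every position and can produce degenerate attention during capture, while all-ones is a valid "attend to everything" placeholder. Illustrative only, using the public SDPA API rather than sglang's attention backends:

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 4, 8, 16)
    all_allowed = torch.ones(1, 4, 8, 8, dtype=torch.bool)  # True = may attend
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=all_allowed)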
@@ -660,9 +640,20 @@ class CudaGraphRunner:
                 retrive_next_token=None,
                 retrive_next_sibling=None,
                 retrive_cum_len=None,
-                draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                 spec_steps=self.model_runner.server_args.speculative_num_steps,
+                topk=self.model_runner.server_args.speculative_eagle_topk,
+                draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                 capture_hidden_mode=CaptureHiddenMode.FULL,
             )
         return spec_info
+
+
+CUDA_GRAPH_CAPTURE_FAILED_MSG = (
+    "Possible solutions:\n"
+    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
+    "3. disable torch compile by not using --enable-torch-compile\n"
+    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
+    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+)
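The troubleshooting text each runner used to inline is now hoisted into the module-level CUDA_GRAPH_CAPTURE_FAILED_MSG, which the two EAGLE runners below import, so the hint is maintained in one place. The shape of the refactor, with hypothetical trimmed names:

    # Hypothetical, trimmed illustration of the shared-constant pattern.
    CAPTURE_FAILED_MSG = "Possible solutions:\n1. lower --mem-fraction-static\n"

    def capture_or_raise(capture_fn):
        try:
            capture_fn()
        except RuntimeError as e:
            # Append the shared troubleshooting hint to every capture failure.
            raise Exception(f"Capture cuda graph failed: {e}\n{CAPTURE_FAILED_MSG}")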
python/sglang/srt/server_args.py

@@ -447,7 +447,7 @@ class ServerArgs:
                 self.speculative_num_steps,
                 self.speculative_eagle_topk,
                 self.speculative_num_draft_tokens,
-            ) = auto_choose_speculative_params(model_arch)
+            ) = auto_choose_speculative_params(self)

             if self.page_size > 1 and self.speculative_eagle_topk > 1:
                 self.speculative_eagle_topk = 1
@@ -1655,12 +1655,23 @@ def get_model_arch(args: ServerArgs):
     return hf_config.architectures[0]


-def auto_choose_speculative_params(arch: str):
+def auto_choose_speculative_params(self: ServerArgs):
     """
     Automatically choose the parameters for speculative decoding.

     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
+    kwargs = {}
+    hf_config = get_config(
+        self.model_path,
+        trust_remote_code=self.trust_remote_code,
+        revision=self.revision,
+        model_override_args=json.loads(self.json_model_override_args),
+        **kwargs,
+    )
+    arch = hf_config.architectures[0]
+
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
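auto_choose_speculative_params now receives the whole ServerArgs and resolves the architecture itself via get_config, rather than trusting a pre-computed arch string. Resolving architectures[0] from a Hugging Face config generally looks like the following sketch (public transformers API; sglang's get_config is a wrapper with extra override handling, and the model path here is hypothetical):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model path
        trust_remote_code=False,
    )
    arch = cfg.architectures[0]  # e.g. "LlamaForCausalLM"
    if arch in ["LlamaForCausalLM"]:
        num_steps, eagle_topk, num_draft_tokens = (5, 4, 8)  # llama defaults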
python/sglang/srt/speculative/build_eagle_tree.py

@@ -4,7 +4,7 @@ from typing import List
 import torch

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, rank0_print

 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )

-    first_rank_print("=========== build tree kernel efficient ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_next_token=}", flush=True)
-    first_rank_print(f"{retrive_next_sibling=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
+    rank0_print("=========== build tree kernel efficient ==========")
+    # rank0_print(f"{tree_mask=}", flush=True)
+    rank0_print(f"{position=}", flush=True)
+    rank0_print(f"{retrive_index=}", flush=True)
+    rank0_print(f"{retrive_next_token=}", flush=True)
+    rank0_print(f"{retrive_next_sibling=}", flush=True)
+    rank0_print(f"{draft_tokens=}", flush=True)

     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
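This file only renames the test's debug-print helper to the rank0_print utility now imported from sglang.srt.utils. A helper of this shape prints once per job instead of once per GPU process; a generic sketch (not sglang's exact implementation):

    import torch.distributed as dist

    def rank0_print(*args, **kwargs):
        # Emit output only on rank 0 (or when not running distributed),
        # so multi-GPU test logs are not duplicated per process.
        if not dist.is_initialized() or dist.get_rank() == 0:
            print(*args, **kwargs)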
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
 import torch

 from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
     get_batch_sizes_to_capture,
     get_global_graph_memory_pool,
@@ -73,7 +74,7 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.max_bs, self.model_runner.model_config.hidden_size),
+            (self.max_num_token, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )
@@ -82,13 +83,7 @@ class EAGLEDraftCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     def can_run(self, forward_batch: ForwardBatch):
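The substantive fix here is the hidden-state buffer shape: it is now sized by max_num_token instead of max_bs. With EAGLE drafting at topk > 1, each request carries several draft branches per step, so the buffer must hold one row per token rather than one per sequence; sizing by batch alone under-allocates. Roughly, as an assumed relationship rather than the constructor's exact formula:

    max_bs = 64               # hypothetical largest captured batch size
    topk = 4                  # hypothetical speculative_eagle_topk
    num_tokens_per_bs = topk  # assumed tokens contributed per request per step
    max_num_token = max_bs * num_tokens_per_bs  # rows the buffer must hold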
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Callable
 import torch

 from sglang.srt.model_executor.cuda_graph_runner import (
+    CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
     LogitsProcessorOutput,
     get_batch_sizes_to_capture,
@@ -89,13 +90,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.capture()
         except RuntimeError as e:
             raise Exception(
-                f"Capture CUDA graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )

     def can_run(self, forward_batch: ForwardBatch):
@@ -200,7 +195,6 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
-        assert raw_bs * self.num_tokens_per_bs == num_tokens

         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -224,9 +218,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.seq_lens_cpu.fill_(1)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

+        forward_batch.spec_info.positions = None
         if bs != raw_bs:
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
-            forward_batch.spec_info.positions = None

         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
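Besides sharing the capture-failure constant, this runner drops the raw_bs * num_tokens_per_bs == num_tokens assertion and hoists spec_info.positions = None out of the padding branch, so it is reset on every replay, padded or not. The padding itself rounds the live batch size up to the nearest captured graph size via bisect, as in this small sketch (capture_bs is assumed sorted ascending):

    import bisect

    capture_bs = [1, 2, 4, 8, 16, 32]  # hypothetical captured sizes
    raw_bs = 5                         # live batch size
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]  # -> 8 (padded)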
python/sglang/srt/speculative/eagle_utils.py

@@ -232,8 +232,9 @@ class EagleVerifyInput:
     retrive_next_token: torch.Tensor
     retrive_next_sibling: torch.Tensor
     retrive_cum_len: torch.Tensor
-    draft_token_num: int
     spec_steps: int
+    topk: int
+    draft_token_num: int
     capture_hidden_mode: CaptureHiddenMode
     grammar: BaseGrammarObject = None
@@ -270,16 +271,17 @@ class EagleVerifyInput:
         )
         return cls(
-            draft_tokens,
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_next_token,
-            retrive_next_sibling,
-            None,
-            num_verify_tokens,
-            spec_steps,
-            CaptureHiddenMode.FULL,
+            draft_token=draft_tokens,
+            custom_mask=tree_mask,
+            positions=position,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            retrive_cum_len=None,
+            spec_steps=spec_steps,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
         )

     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
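The call-site change from positional to keyword arguments is what makes the field reorder in EagleVerifyInput (topk inserted ahead of draft_token_num) safe: positional construction of a dataclass binds values by position, so reordering fields silently mismatches them, while keyword construction is immune. A minimal demonstration:

    from dataclasses import dataclass

    @dataclass
    class VerifyCfg:  # hypothetical stand-in for EagleVerifyInput's tail fields
        spec_steps: int
        topk: int
        draft_token_num: int

    # Keyword construction keeps bindings correct even if fields are reordered.
    cfg = VerifyCfg(spec_steps=5, topk=4, draft_token_num=8)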