Unverified Commit ad20b795 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)


Co-authored-by: default avatarkavioyu <kavioyu@tencent.com>
parent 9183c23e
......@@ -61,10 +61,10 @@ For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `
```bash
# node 1
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
# node 2
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
```
If you have two H100 nodes, the usage is similar to the aforementioned H20.
......
......@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
......@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
tree_cache=None,
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
......
......@@ -26,7 +26,7 @@ class AttentionBackend(ABC):
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_token: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
......
......@@ -227,7 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_token: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
......@@ -243,9 +243,11 @@ class FlashInferAttnBackend(AttentionBackend):
"NHD",
use_cuda_graph=True,
use_tensor_cores=self.decode_use_tensor_cores,
paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
paged_kv_last_page_len_buffer=self.kv_last_page_len[
:num_tokens
],
)
)
seq_lens_sum = seq_lens.sum().item()
......
......@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_token: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
......
......@@ -575,8 +575,8 @@ class ScheduleBatch:
device: str = "cuda"
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
spec_algorithm: Optional[SpeculativeAlgorithm] = None
@classmethod
def init_new(
......@@ -587,7 +587,7 @@ class ScheduleBatch:
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
spec_algorithm: SpeculativeAlgorithm,
):
return cls(
reqs=reqs,
......@@ -600,7 +600,7 @@ class ScheduleBatch:
has_stream=any(req.stream for req in reqs),
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=speculative_algorithm,
spec_algorithm=spec_algorithm,
)
def batch_size(self):
......@@ -1010,6 +1010,8 @@ class ScheduleBatch:
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
if self.spec_algorithm.is_eagle():
return
self.input_ids = self.output_ids
self.output_ids = None
......@@ -1172,6 +1174,7 @@ class ScheduleBatch:
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
)
def __str__(self):
......@@ -1232,8 +1235,8 @@ class ModelWorkerBatch:
input_embeds: Optional[torch.tensor] = None
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
spec_algorithm: Optional[SpeculativeAlgorithm] = None
@triton.jit
......
......@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
......@@ -116,6 +117,14 @@ class Scheduler:
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.decode_mem_cache_buf_multiplier = (
self.server_args.speculative_num_draft_tokens
if not self.spec_algorithm.is_none()
else 1
)
# Init inter-process communication
context = zmq.Context(2)
......@@ -199,6 +208,21 @@ class Scheduler:
nccl_port=port_args.nccl_port,
)
# Launch worker for speculative decoding if need
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
self.draft_worker = EAGLEWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None
# Get token and memory info from the model worker
(
self.max_total_num_tokens,
......@@ -855,6 +879,7 @@ class Scheduler:
self.tree_cache,
self.model_config,
self.enable_overlap,
self.spec_algorithm,
)
new_batch.prepare_for_extend()
......@@ -888,11 +913,15 @@ class Scheduler:
return None
# Check if decode out of memory
if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
test_retract and batch.batch_size() > 10
):
old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode()
self.new_token_ratio = new_token_ratio
if self.draft_worker:
self.draft_worker.finish_request(retracted_reqs)
logger.info(
"Decode out of memory happened. "
......@@ -926,11 +955,17 @@ class Scheduler:
self.forward_ct += 1
if self.is_generation:
model_worker_batch = batch.get_model_worker_batch()
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
logits_output, next_token_ids, model_worker_batch, spec_info = (
self.draft_worker.forward_batch_speculative_generation(batch)
)
batch.spec_info = spec_info
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
......@@ -1077,7 +1112,10 @@ class Scheduler:
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue
req.output_ids.append(next_token_id)
if batch.spec_algorithm.is_none():
# speculative worker will solve the output_ids in speculative decoding
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
......@@ -1252,6 +1290,9 @@ class Scheduler:
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
or (not req.stream and len(req.output_ids) % 50 == 0)
):
if self.draft_worker and req.finished():
self.draft_worker.finish_request(req)
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
......@@ -1383,6 +1424,7 @@ class Scheduler:
self.tree_cache,
self.model_config,
self.enable_overlap,
self.spec_algorithm,
)
idle_batch.prepare_for_idle()
return idle_batch
......
......@@ -45,13 +45,18 @@ class TpModelWorker:
tp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
is_draft_worker: bool = False,
):
# Parse args
self.tp_rank = tp_rank
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
(
server_args.model_path
if not is_draft_worker
else server_args.speculative_draft_model_path
),
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
......@@ -68,6 +73,7 @@ class TpModelWorker:
tp_size=server_args.tp_size,
nccl_port=nccl_port,
server_args=server_args,
is_draft_worker=is_draft_worker,
)
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
......
......@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -106,11 +106,6 @@ def set_torch_compile_config():
torch._dynamo.config.cache_size_limit = 1024
@maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
......@@ -157,6 +152,17 @@ class CudaGraphRunner:
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
if model_runner.spec_algorithm.is_eagle():
if self.model_runner.is_draft_worker:
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_eagle_topk
)
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_num_draft_tokens
)
self.compile_bs = (
[
bs
......@@ -192,6 +198,13 @@ class CudaGraphRunner:
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
# Speculative_inference
if model_runner.spec_algorithm.is_eagle():
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
self.encoder_lens = torch.full(
......@@ -234,9 +247,6 @@ class CudaGraphRunner:
self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if not forward_batch.forward_mode.is_cuda_graph():
return False
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
forward_batch.global_num_tokens
......@@ -291,21 +301,18 @@ class CudaGraphRunner:
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_token = bs * self.num_tokens_per_bs
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
input_ids = self.input_ids[:num_token]
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:num_token]
positions = self.positions[:num_token]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None
seq_lens_sum = seq_lens.sum().item()
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
......@@ -325,20 +332,22 @@ class CudaGraphRunner:
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
seq_lens_sum=seq_lens.sum(),
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * num_token,
top_logprobs_nums=[0] * bs,
positions=positions,
global_num_tokens=global_num_tokens,
mrope_positions=mrope_positions,
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=self.get_spec_info(num_tokens, positions),
)
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
num_token,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
......@@ -394,14 +403,16 @@ class CudaGraphRunner:
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
positions = clamp_position(forward_batch.seq_lens)
self.positions[:raw_num_token].copy_(positions)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs,
......@@ -424,3 +435,36 @@ class CudaGraphRunner:
),
)
return logits_output
def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
spec_info = None
if self.model_runner.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_utils import (
EAGLEDraftInput,
EagleVerifyInput,
)
if self.model_runner.is_draft_worker:
spec_info = EAGLEDraftInput()
spec_info.hidden_states = self.hidden_states[:num_tokens]
spec_info.positions = positions
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
spec_info.init(self.model_runner.server_args)
else:
spec_info = EagleVerifyInput(
None,
None,
None,
None,
None,
None,
self.model_runner.server_args.speculative_num_draft_tokens,
)
spec_info.custom_mask = torch.zeros(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
)
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
return spec_info
......@@ -38,6 +38,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import maybe_torch_compile
if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
......@@ -276,10 +277,21 @@ class ForwardBatch:
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
return ret
# Override the positions with spec_info
if (
ret.spec_info is not None
and getattr(ret.spec_info, "positions", None) is not None
):
ret.positions = ret.spec_info.positions
# Init position information
if not ret.forward_mode.is_decode():
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
......@@ -288,13 +300,15 @@ class ForwardBatch:
).to(device, non_blocking=True)
if model_runner.server_args.attention_backend != "torch_native":
ret.extend_num_tokens = batch.extend_num_tokens
ret.positions, ret.extend_start_loc = compute_position_triton(
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
)
else:
ret.positions, ret.extend_start_loc = compute_position_torch(
positions, ret.extend_start_loc = compute_position_torch(
ret.extend_prefix_lens, ret.extend_seq_lens
)
if ret.positions is None:
ret.positions = positions
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
......@@ -383,6 +397,11 @@ def compute_position_torch(
return positions.to(torch.int64), extend_start_loc
@maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
class CaptureHiddenMode(IntEnum):
NULL = auto()
FULL = auto()
......
......@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
......@@ -74,6 +75,7 @@ class ModelRunner:
tp_size: int,
nccl_port: int,
server_args: ServerArgs,
is_draft_worker: bool = False,
):
# Parse args
self.model_config = model_config
......@@ -84,8 +86,12 @@ class ModelRunner:
self.tp_size = tp_size
self.dist_port = nccl_port
self.server_args = server_args
self.is_draft_worker = is_draft_worker
self.is_generation = model_config.is_generation
self.is_multimodal = model_config.is_multimodal
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
# Model-specific adjustment
if (
......@@ -205,14 +211,18 @@ class ModelRunner:
else:
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend=backend,
world_size=self.tp_size,
rank=self.tp_rank,
local_rank=self.gpu_id,
distributed_init_method=dist_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
if not self.is_draft_worker:
# Only initilzie the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
world_size=self.tp_size,
rank=self.tp_rank,
local_rank=self.gpu_id,
distributed_init_method=dist_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
min_per_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
)
......@@ -407,7 +417,6 @@ class ModelRunner:
target_dtype = (
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
)
current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
assert (
self._model_update_group is not None
......@@ -506,6 +515,28 @@ class ModelRunner:
)
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_num_reqs is None:
max_num_reqs = min(
max(
int(
self.max_total_num_tokens / self.model_config.context_len * 512
),
2048,
),
4096,
)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
else:
self.server_args.draft_runner_cache_size = (
self.max_total_num_tokens
+ max_num_reqs * self.server_args.speculative_num_steps
+ 100
)
if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:
logging.warning(
......@@ -520,17 +551,6 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)
if max_num_reqs is None:
max_num_reqs = min(
max(
int(
self.max_total_num_tokens / self.model_config.context_len * 512
),
2048,
),
4096,
)
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
......@@ -650,10 +670,6 @@ class ModelRunner:
tensor_parallel(self.model, device_mesh)
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
self.attn_backend.init_forward_metadata(forward_batch)
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
......@@ -683,14 +699,18 @@ class ModelRunner:
)
def forward_idle(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
if (
forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch)
):
return self.cuda_graph_runner.replay(forward_batch)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
......
......@@ -23,6 +23,7 @@ from typing import List, Optional
import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
......@@ -247,6 +248,17 @@ class ServerArgs:
"Overlap scheduler is disabled."
)
# Speculative Decoding
if self.speculative_algorithm == "EAGLE":
self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True
self.disable_radix_cache = True
self.disable_overlap_schedule = True
self.chunked_prefill_size = -1
logger.info(
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
)
# GGUF
if (
self.load_format == "auto" or self.load_format == "gguf"
......
......@@ -2,8 +2,12 @@ from enum import IntEnum, auto
class SpeculativeAlgorithm(IntEnum):
NONE = auto()
EAGLE = auto()
def is_none(self):
return self == SpeculativeAlgorithm.NONE
def is_eagle(self):
return self == SpeculativeAlgorithm.EAGLE
......@@ -11,6 +15,7 @@ class SpeculativeAlgorithm(IntEnum):
def from_string(name: str):
name_map = {
"EAGLE": SpeculativeAlgorithm.EAGLE,
None: SpeculativeAlgorithm.NONE,
}
return name_map[name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment