Unverified Commit 9ed6ee92 authored by Bryan Lu's avatar Bryan Lu Committed by GitHub
Browse files

[Bugfix] EAGLE output norm bug (#14464)


Signed-off-by: default avatarBryan Lu <yuzhelu@amazon.com>
parent ee3778d5
...@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub: ...@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:
## Speculating using EAGLE based draft models ## Speculating using EAGLE based draft models
The following code configures vLLM to use speculative decoding where proposals are generated by The following code configures vLLM to use speculative decoding where proposals are generated by
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request-level acceptance rate, can be found [here](<gh-file:examples/offline_inference/eagle.py>).
```python ```python
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
......
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# Offline example: run a target model with an EAGLE draft model through vLLM's
# speculative decoding, then report the mean number of tokens accepted per
# forward pass, aggregated over all requests.

# --- Command-line options ---------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="./examples/data/gsm8k.jsonl",
    help="downloaded from the eagle repo "
    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/",
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
args = parser.parse_args()

print(args)

# --- Models -----------------------------------------------------------------
model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

max_model_len = 2048

tokenizer = AutoTokenizer.from_pretrained(model_dir)

# --- Prompts ----------------------------------------------------------------
# Each dataset line is a JSON object; the prompt is the first entry of its
# "turns" list. When the dataset file is missing, fall back to two built-in
# prompts so the example still runs.
if os.path.exists(args.dataset):
    with open(args.dataset, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline at the end of the file),
        # which json.loads would otherwise reject.
        prompts = [
            json.loads(line)["turns"][0] for line in f if line.strip()
        ]
else:
    prompts = ["The future of AI is", "The president of the United States is"]

prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)

# Wrap every prompt as a single user turn and tokenize it with the chat
# template of the target model's tokenizer.
prompt_ids = [
    tokenizer.apply_chat_template([{
        "role": "user",
        "content": prompt
    }],
                                  add_generation_prompt=True)
    for prompt in prompts
]

# --- Engine -----------------------------------------------------------------
llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=args.tp,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    enforce_eager=args.enforce_eager,
    max_model_len=max_model_len,
    max_num_seqs=args.max_num_seqs,
    gpu_memory_utilization=0.8,
    speculative_model=eagle_dir,
    num_speculative_tokens=args.num_spec_tokens,
    speculative_draft_tensor_parallel_size=args.draft_tp,
    speculative_max_model_len=max_model_len,
    disable_log_stats=False,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)

# --- Acceptance statistics --------------------------------------------------
# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
    metrics = output.metrics
    # The per-request counts are optional; treat a missing value as
    # "nothing recorded" instead of failing on iteration.
    # NOTE(review): assumes output.metrics itself may be None -- confirm.
    step_counts = (metrics.spec_token_acceptance_counts
                   if metrics is not None else None)
    for step, count in enumerate(step_counts or ()):
        acceptance_counts[step] += count

# Index 0 counts the always-accepted target-model token, i.e. the number of
# forward passes; it is the denominator of the mean acceptance length.
num_forward_passes = acceptance_counts[0]
if num_forward_passes > 0:
    mean_acceptance_length = sum(acceptance_counts) / num_forward_passes
    # Build the message on one logical line: a backslash continuation inside
    # the string literal would leak the next line's indentation into the
    # printed text.
    print(f"mean acceptance length: {mean_acceptance_length:.2f}")
else:
    # Nothing was generated or recorded; avoid a ZeroDivisionError.
    print("mean acceptance length: n/a (no forward passes were recorded)")
...@@ -853,6 +853,10 @@ class LLMEngine: ...@@ -853,6 +853,10 @@ class LLMEngine:
self.generation_config_fields, seq.eos_token_id) self.generation_config_fields, seq.eos_token_id)
# Create the sequence group. # Create the sequence group.
draft_size = 1
if self.vllm_config.speculative_config is not None:
draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup( seq_group = SequenceGroup(
request_id=request_id, request_id=request_id,
seqs=[seq], seqs=[seq],
...@@ -862,7 +866,8 @@ class LLMEngine: ...@@ -862,7 +866,8 @@ class LLMEngine:
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority,
draft_size=draft_size)
return seq_group return seq_group
......
...@@ -100,6 +100,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -100,6 +100,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
seqs = sequence_group.get_seqs( seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED) status=SequenceStatus.FINISHED_ABORTED)
for output in outputs:
if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
sequence_group.metrics.spec_token_acceptance_counts[
output.step_index] += 1
assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, ( assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.") "Beam search not supported in multi-step decoding.")
......
...@@ -38,7 +38,7 @@ class DummyOutputNorm(nn.Module): ...@@ -38,7 +38,7 @@ class DummyOutputNorm(nn.Module):
if residual is None: if residual is None:
return x return x
else: else:
return x, residual return x + residual, None
class EAGLE(nn.Module): class EAGLE(nn.Module):
......
...@@ -111,6 +111,13 @@ class RequestMetrics: ...@@ -111,6 +111,13 @@ class RequestMetrics:
model_execute_time: The time spent in the model execute function. This model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time. workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from
the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
""" """
arrival_time: float arrival_time: float
last_token_time: float last_token_time: float
...@@ -121,6 +128,7 @@ class RequestMetrics: ...@@ -121,6 +128,7 @@ class RequestMetrics:
scheduler_time: Optional[float] = None scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None model_execute_time: Optional[float] = None
spec_token_acceptance_counts: Optional[list[int]] = None
class SequenceDataDelta( class SequenceDataDelta(
...@@ -639,22 +647,25 @@ class SequenceGroup: ...@@ -639,22 +647,25 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request. priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate
for single-draft speculative decoding but larger than
that for multi-draft SD (currently not supported).
""" """
def __init__( def __init__(self,
self, request_id: str,
request_id: str, seqs: list[Sequence],
seqs: list[Sequence], arrival_time: float,
arrival_time: float, sampling_params: Optional[SamplingParams] = None,
sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None,
lora_request: Optional[LoRARequest] = None, pooling_params: Optional[PoolingParams] = None,
pooling_params: Optional[PoolingParams] = None, pooled_data: Optional[torch.Tensor] = None,
pooled_data: Optional[torch.Tensor] = None, encoder_seq: Optional[Sequence] = None,
encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0,
priority: int = 0, draft_size: int = 1) -> None:
) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs = seqs self.seqs = seqs
self.first_seq = seqs[0] self.first_seq = seqs[0]
...@@ -667,7 +678,9 @@ class SequenceGroup: ...@@ -667,7 +678,9 @@ class SequenceGroup:
last_token_time=arrival_time, last_token_time=arrival_time,
first_scheduled_time=None, first_scheduled_time=None,
first_token_time=None, first_token_time=None,
time_in_queue=None) time_in_queue=None,
spec_token_acceptance_counts=[0] *
draft_size)
self.last_token_latency = 0.0 self.last_token_latency = 0.0
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
...@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput( ...@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
samples: list[SequenceOutput] samples: list[SequenceOutput]
# Prompt logprob for each prompt query token. # Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs] prompt_logprobs: Optional[PromptLogprobs]
step_index: Optional[int] = 0
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, " return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
......
...@@ -1080,7 +1080,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -1080,7 +1080,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
[sequence_index][:num_logprobs], [sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index] topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs], [sequence_index][:num_logprobs],
)) step_index=step_index))
sampler_output_list.append( sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids)) SamplerOutput(outputs=step_output_token_ids))
......
...@@ -93,14 +93,14 @@ def create_logprobs_output( ...@@ -93,14 +93,14 @@ def create_logprobs_output(
def create_sequence_group_output( def create_sequence_group_output(
token_id: int, token_id: int,
token_id_logprob_rank: int, token_id_logprob_rank: int,
token_id_logprob: float, token_id_logprob: float,
seq_id: SeqId, seq_id: SeqId,
topk_token_ids: List[Optional[int]], topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]], topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None, prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput: step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results. """Create a SequenceGroupOutput given the sampling results.
Args: Args:
...@@ -110,6 +110,7 @@ def create_sequence_group_output( ...@@ -110,6 +110,7 @@ def create_sequence_group_output(
seq_id (int): The sequence id. seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs. topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
step_index: (Optional[int]): The index of the speculative token.
""" """
logprobs = create_logprobs_output( logprobs = create_logprobs_output(
...@@ -120,14 +121,13 @@ def create_sequence_group_output( ...@@ -120,14 +121,13 @@ def create_sequence_group_output(
topk_logprobs, topk_logprobs,
) )
return CompletionSequenceGroupOutput( return CompletionSequenceGroupOutput(samples=[
samples=[ SequenceOutput(parent_seq_id=seq_id,
SequenceOutput(parent_seq_id=seq_id, output_token=token_id,
output_token=token_id, logprobs=logprobs)
logprobs=logprobs) ],
], prompt_logprobs=prompt_logprobs,
prompt_logprobs=prompt_logprobs, step_index=step_index)
)
def split_batch_by_proposal_len( def split_batch_by_proposal_len(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment