Commit 0c1cd0f5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Revert "Merge remote-tracking branch 'origin/v0.9.2-dev-wm' into v0.9.2-dev"

This reverts merge request !169
parent 0d4ff65d
...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
@support_torch_compile #@support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -59,9 +56,6 @@ class EagleProposer: ...@@ -59,9 +56,6 @@ class EagleProposer:
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and == CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager) not self.vllm_config.model_config.enforce_eager)
self.use_full_cuda_graph = (
self.use_cuda_graph
and vllm_config.compilation_config.full_cuda_graph)
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
reversed( reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes)) self.vllm_config.compilation_config.cudagraph_capture_sizes))
...@@ -77,9 +71,6 @@ class EagleProposer: ...@@ -77,9 +71,6 @@ class EagleProposer:
(self.max_num_tokens, self.hidden_size), (self.max_num_tokens, self.hidden_size),
dtype=self.dtype, dtype=self.dtype,
device=device) device=device)
# attention metadata captured in full cudagraph mode
self.attn_metadata_cudagraph = None
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
...@@ -107,7 +98,7 @@ class EagleProposer: ...@@ -107,7 +98,7 @@ class EagleProposer:
num_rejected_tokens: list[int], num_rejected_tokens: list[int],
# [batch_size] # [batch_size]
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
) -> tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1 last_token_indices = cu_num_tokens[1:] - 1
...@@ -166,7 +157,7 @@ class EagleProposer: ...@@ -166,7 +157,7 @@ class EagleProposer:
# FIXME: need to consider multiple kv_cache_groups # FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build( attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0, common_prefix_len=0,
common_attn_metadata=common_attn_metadata common_attn_metadata=common_attn_metadata,
) )
else: else:
raise ValueError(f"Unsupported method: {self.method}") raise ValueError(f"Unsupported method: {self.method}")
...@@ -185,38 +176,6 @@ class EagleProposer: ...@@ -185,38 +176,6 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if (self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
...@@ -233,14 +192,10 @@ class EagleProposer: ...@@ -233,14 +192,10 @@ class EagleProposer:
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list = [draft_prob]
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, draft_prob.shape[-1]) return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has # TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once # one layer. Adapt this code to support multiple layers once
...@@ -275,7 +230,7 @@ class EagleProposer: ...@@ -275,7 +230,7 @@ class EagleProposer:
seq_lens=(seq_lens + 1), seq_lens=(seq_lens + 1),
) )
for i in range(self.num_speculative_tokens - 1): for _ in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default. # tensor.argmax() returns int64 by default.
...@@ -327,43 +282,6 @@ class EagleProposer: ...@@ -327,43 +282,6 @@ class EagleProposer:
self.positions[:batch_size] = clamped_positions self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model. # Run the model.
with set_forward_context(per_layer_attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -387,14 +305,9 @@ class EagleProposer: ...@@ -387,14 +305,9 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_list.append(draft_prob)
# [batch_size, num_speculative_tokens] # [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
draft_probs = torch.stack(draft_probs_list, dim=1).contiguous() return draft_token_ids
return draft_token_ids, draft_probs
@staticmethod @staticmethod
def prepare_inputs( def prepare_inputs(
...@@ -429,7 +342,7 @@ class EagleProposer: ...@@ -429,7 +342,7 @@ class EagleProposer:
) )
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_eagle_input_kernel[(batch_size, )]( prepare_eagle_input_kernel[(batch_size,)](
token_indices, token_indices,
cu_target_query_lens, cu_target_query_lens,
cu_num_tokens, cu_num_tokens,
...@@ -491,13 +404,8 @@ class EagleProposer: ...@@ -491,13 +404,8 @@ class EagleProposer:
def dummy_run( def dummy_run(
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
if attn_metadata is not None and self.attn_metadata_cudagraph is None: with set_forward_context(None, self.vllm_config,
self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]]
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
self.model( self.model(
self.input_ids[:num_tokens], self.input_ids[:num_tokens],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from abc import ABC
import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -43,41 +39,3 @@ def prepare_eagle_input_kernel( ...@@ -43,41 +39,3 @@ def prepare_eagle_input_kernel(
index_start + offset, index_start + offset,
mask=offset < num_tokens, mask=offset < num_tokens,
) )
class DraftProbs(ABC): # type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences."""
# spec tokens probs.
draft_probs: torch.Tensor
# The request id list.
_req_ids: list[str]
def __init__(self, draft_probs, req_ids):
assert len(req_ids) == len(draft_probs)
self.draft_probs = draft_probs
self._req_ids = req_ids
def update(self,
draft_probs: torch.Tensor,
tmp_req_ids: list[str]):
diff_req_ids = [item for item in self._req_ids if item not in tmp_req_ids]
index = [self._req_ids.index(req_id) for req_id in diff_req_ids]
self._req_ids = diff_req_ids
self.draft_probs = self.draft_probs[index]
self.draft_probs = torch.cat([self.draft_probs, draft_probs])
self._req_ids.extend(tmp_req_ids)
assert len(self._req_ids) == len(self.draft_probs)
def prune(self, req_ids: list[str]):
new_req_ids = [req_id for req_id in self._req_ids if req_id not in req_ids]
if new_req_ids != self._req_ids:
# Batch contents changed - prune removed sequences.
index = [self._req_ids.index(req_id) for req_id in new_req_ids]
self.draft_probs = self.draft_probs[index]
self._req_ids = new_req_ids
def get_probs(self, req_ids: list[str]):
index = [self._req_ids.index(req_id) for req_id in req_ids]
return self.draft_probs[index]
...@@ -58,13 +58,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ...@@ -58,13 +58,11 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.rejection_sampler_mtp import MtpRejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -194,12 +192,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -194,12 +192,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
raise ValueError("Unknown speculative decoding method: " raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}") f"{self.speculative_config.method}")
self.use_mtp = self.speculative_config.method == "deepseek_mtp"
if not self.use_mtp:
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler()
else:
self.rejection_sampler = MtpRejectionSampler()
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
...@@ -326,8 +319,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -326,8 +319,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`. # from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {} self.shared_kv_cache_layers: dict[str, str] = {}
self.draft_probs : Optional[DraftProbs] = None
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
...@@ -387,10 +378,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -387,10 +378,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
# prune draft probs of finished requests
if self.use_mtp and self.draft_probs is not None and len(scheduler_output.finished_req_ids) > 0:
self.draft_probs.prune(list(scheduler_output.finished_req_ids))
# Free the cached encoder outputs. # Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids: for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id) encoder_outputs = self.encoder_cache.get(req_id)
...@@ -548,7 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -548,7 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
spec_token_ids = ( spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids: if spec_token_ids:
num_spec_tokens = len(spec_token_ids) num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index] start_index = self.input_batch.num_tokens_no_spec[req_index]
...@@ -1472,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1472,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_logits = logits[spec_decode_metadata.target_logits_indices] target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler( output_token_ids = self.rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
self.draft_probs.get_probs(self.input_batch.req_ids) \ None, # draft_probs
if self.draft_probs is not None else None, # draft_probs
target_logits, target_logits,
bonus_token_ids, bonus_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1558,7 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1558,7 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
else: else:
spec_token_ids, draft_probs = self.propose_draft_token_ids( spec_token_ids = self.propose_draft_token_ids(
scheduler_output, scheduler_output,
valid_sampled_token_ids, valid_sampled_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1569,15 +1554,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1569,15 +1554,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
) )
if self.use_mtp:
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
spec_token_ids = spec_token_ids.tolist()
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
...@@ -1594,7 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1594,7 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pooler_output=[], pooler_output=[],
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits num_nans_in_logits=num_nans_in_logits,
) )
def propose_draft_token_ids( def propose_draft_token_ids(
...@@ -1607,8 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1607,8 +1583,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor], aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata], spec_decode_metadata: Optional[SpecDecodeMetadata],
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any],
) -> tuple[list[list[int]], torch.Tensor]: ) -> list[list[int]]:
draft_probs = None
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer) assert isinstance(self.drafter, NgramProposer)
...@@ -1707,7 +1682,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1707,7 +1682,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
spec_token_ids, draft_probs = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
...@@ -1718,8 +1693,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1718,8 +1693,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
num_rejected_tokens=num_rejected_tokens num_rejected_tokens=num_rejected_tokens
) )
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs return spec_token_ids
def kv_connector_no_forward( def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
...@@ -2108,7 +2083,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2108,7 +2083,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens, attn_metadata) self.drafter.dummy_run(num_tokens)
# This is necessary to avoid blocking DP. # This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real # For dummy runs, we typically skip EPLB since we don't have any real
...@@ -2175,10 +2150,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2175,10 +2150,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids, self.device) draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids) num_tokens = sum(len(ids) for ids in draft_token_ids)
draft_probs = torch.randn( # draft_probs = torch.randn(
num_tokens, logits.shape[-1], device=self.device, # num_tokens, logits.shape[-1], device=self.device,
dtype=logits.dtype) # dtype=logits.dtype)
# draft_probs = None draft_probs = None
target_logits = torch.randn(num_tokens, target_logits = torch.randn(num_tokens,
logits.shape[-1], logits.shape[-1],
device=self.device, device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment