Commit 2fc5b0bb authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents b550cf96 48742057
......@@ -87,12 +87,8 @@ class OptRejectionSampler(nn.Module):
assert metadata.max_spec_len <= MAX_SPEC_LEN
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
draft_token_ids = metadata.draft_token_ids
mask = draft_token_ids.eq(-1).to(torch.bool)
draft_token_ids = torch.where(mask, 0, draft_token_ids).to(torch.long) # 兼容第一次decode
output_token_ids = rejection_sample(
draft_token_ids,
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
......@@ -225,6 +221,8 @@ def rejection_random_sample_kernel(
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if draft_token_id < 0:
draft_token_id = 0
if NO_DRAFT_PROBS:
draft_prob = 1
else:
......@@ -235,6 +233,7 @@ def rejection_random_sample_kernel(
(start_idx + pos) * vocab_size +
draft_token_id)
draft_token_id = draft_token_id.to(tl.int64)
target_token_id = tl.load(target_token_ids_ptr + (start_idx + pos))
target_token_id = target_token_id.to(tl.int64)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
......
......@@ -6,6 +6,7 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils import async_tensor_h2d
_SAMPLING_EPS = 1e-5
......@@ -80,5 +81,10 @@ class DraftProbs(ABC): # type: ignore[call-arg]
def get_probs(self, req_ids: list[str]):
index = [self._req_ids.index(req_id) for req_id in req_ids]
return self.draft_probs[index]
index_tensor = async_tensor_h2d(
index,
dtype=torch.int32,
target_device=self.draft_probs.device,
pin_memory=True)
return self.draft_probs[index_tensor]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment