Unverified commit a3299c3d, authored by Nick Hill, committed by GitHub
Browse files

[Model Runner V2] Misc code simplification (#35941)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 6c21a0c2
......@@ -85,7 +85,6 @@ class UvaBackedTensor:
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
):
self.dtype = dtype
self.max_concurrency = max_concurrency
# Source of truth
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
......
......@@ -96,11 +96,7 @@ logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
......@@ -627,9 +623,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs, dtype=torch.int32, device=self.device
)
else:
num_draft_tokens = np.array(
[len(draft_tokens.get(req_id, ())) for req_id in req_ids],
num_draft_tokens = np.fromiter(
(len(draft_tokens.get(req_id, ())) for req_id in req_ids),
dtype=np.int32,
count=num_reqs,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
......@@ -782,9 +779,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
num_sampled = torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
)
num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
else:
# Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample(
......
......@@ -48,17 +48,8 @@ def rejection_sample(
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
num_sampled = cu_num_logits.new_empty(num_reqs)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment