"vscode:/vscode.git/clone" did not exist on "6ebaf43ee4a6fbbeba685315d605536db1c0c471"
Unverified commit a3299c3d, authored by Nick Hill, committed by GitHub
Browse files

[Model Runner V2] Misc code simplification (#35941)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 6c21a0c2
...@@ -85,7 +85,6 @@ class UvaBackedTensor: ...@@ -85,7 +85,6 @@ class UvaBackedTensor:
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2 self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
): ):
self.dtype = dtype self.dtype = dtype
self.max_concurrency = max_concurrency
# Source of truth # Source of truth
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False) self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
......
...@@ -96,11 +96,7 @@ logger = init_logger(__name__) ...@@ -96,11 +96,7 @@ logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin): class GPUModelRunner(LoRAModelRunnerMixin):
def __init__( def __init__(self, vllm_config: VllmConfig, device: torch.device):
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
...@@ -627,9 +623,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -627,9 +623,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs, dtype=torch.int32, device=self.device num_reqs, dtype=torch.int32, device=self.device
) )
else: else:
num_draft_tokens = np.array( num_draft_tokens = np.fromiter(
[len(draft_tokens.get(req_id, ())) for req_id in req_ids], (len(draft_tokens.get(req_id, ())) for req_id in req_ids),
dtype=np.int32, dtype=np.int32,
count=num_reqs,
) )
total_num_draft_tokens = int(num_draft_tokens.sum()) total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens total_num_logits = num_reqs + total_num_draft_tokens
...@@ -782,9 +779,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -782,9 +779,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if input_batch.num_draft_tokens == 0: if input_batch.num_draft_tokens == 0:
# No draft tokens (common case). # No draft tokens (common case).
num_sampled = torch.ones( num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
input_batch.num_reqs, dtype=torch.int32, device=self.device
)
else: else:
# Rejection sampling for spec decoding. # Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample( sampled_tokens, num_sampled = rejection_sample(
......
...@@ -48,17 +48,8 @@ def rejection_sample( ...@@ -48,17 +48,8 @@ def rejection_sample(
num_speculative_steps: int, num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1 num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty( sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
num_reqs, num_sampled = cu_num_logits.new_empty(num_reqs)
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_rejection_sample_kernel[(num_reqs,)]( _rejection_sample_kernel[(num_reqs,)](
sampled, sampled,
sampled.stride(0), sampled.stride(0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment