Unverified commit d00df624, authored by Woosuk Kwon and committed by GitHub
Browse files

[Model Runner V2] Minor refactoring for penalties (#34662)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent 9752da9d
...@@ -155,9 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -155,9 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
device=self.device, device=self.device,
all_token_ids=self.req_states.all_token_ids.gpu, req_states=self.req_states,
prompt_len=self.req_states.prompt_len.gpu,
total_len=self.req_states.total_len.gpu,
logprobs_mode=self.model_config.logprobs_mode, logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1, num_speculative_tokens=self.num_speculative_steps + 1,
) )
...@@ -528,11 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -528,11 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if scheduler_output.scheduled_new_reqs: if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes() self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes( self.sampler.apply_staged_writes()
self.req_states.all_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len.np,
)
if self.uses_mrope: if self.uses_mrope:
self.mrope_states.apply_staged_writes() self.mrope_states.apply_staged_writes()
......
...@@ -6,24 +6,17 @@ import torch ...@@ -6,24 +6,17 @@ import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
class BadWordsState: class BadWordsState:
def __init__( def __init__(self, req_states: RequestState):
self, self.req_states = req_states
all_token_ids: torch.Tensor, self.max_num_reqs = req_states.max_num_reqs
prompt_len: torch.Tensor, self.device = req_states.device
total_len: torch.Tensor,
):
self.all_token_ids = all_token_ids
self.prompt_len = prompt_len
self.total_len = total_len
self.max_num_reqs = prompt_len.shape[0]
self.device = prompt_len.device
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS] # flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
self.bad_word_token_ids = StagedWriteTensor( self.bad_word_token_ids = StagedWriteTensor(
...@@ -95,9 +88,9 @@ class BadWordsState: ...@@ -95,9 +88,9 @@ class BadWordsState:
self.bad_word_token_ids.gpu, self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu, self.bad_word_offsets.gpu,
self.num_bad_words.gpu, self.num_bad_words.gpu,
self.all_token_ids, self.req_states.all_token_ids.gpu,
self.prompt_len, self.req_states.prompt_len.gpu,
self.total_len, self.req_states.total_len.gpu,
input_ids, input_ids,
expanded_local_pos, expanded_local_pos,
max_num_bad_words, max_num_bad_words,
......
...@@ -6,14 +6,18 @@ import torch ...@@ -6,14 +6,18 @@ import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
class PenaltiesState: class PenaltiesState:
def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device): def __init__(self, req_states: RequestState):
self.max_num_reqs = max_num_reqs self.req_states = req_states
self.vocab_size = vocab_size
self.device = device max_num_reqs = req_states.max_num_reqs
self.vocab_size = req_states.vocab_size
self.device = req_states.device
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
...@@ -26,7 +30,7 @@ class PenaltiesState: ...@@ -26,7 +30,7 @@ class PenaltiesState:
# Statistics for penalties. # Statistics for penalties.
self.prompt_bin_mask = torch.zeros( self.prompt_bin_mask = torch.zeros(
self.max_num_reqs, max_num_reqs,
cdiv(self.vocab_size, 32), cdiv(self.vocab_size, 32),
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
...@@ -34,10 +38,10 @@ class PenaltiesState: ...@@ -34,10 +38,10 @@ class PenaltiesState:
# TODO(woosuk): This tensor is rarely used but can be very large, taking up # TODO(woosuk): This tensor is rarely used but can be very large, taking up
# GBs of GPU memory. Optimize the memory usage. # GBs of GPU memory. Optimize the memory usage.
self.output_bin_counts = torch.zeros( self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
) )
self._penalties_reqs: list[int] = [] self._new_penalties_reqs: list[int] = []
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None: def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
...@@ -47,24 +51,29 @@ class PenaltiesState: ...@@ -47,24 +51,29 @@ class PenaltiesState:
do_penalty = use_penalty(sampling_params) do_penalty = use_penalty(sampling_params)
self.use_penalty[req_idx] = do_penalty self.use_penalty[req_idx] = do_penalty
if do_penalty: if do_penalty:
self._penalties_reqs.append(req_idx) self._new_penalties_reqs.append(req_idx)
def apply_staged_writes( def apply_staged_writes(self) -> None:
self, if self._new_penalties_reqs:
all_token_ids: torch.Tensor, idx_mapping = async_tensor_h2d(
prefill_lens: np.ndarray, self._new_penalties_reqs,
prompt_lens: np.ndarray, dtype=torch.int32,
) -> None: target_device=self.device,
# TODO(woosuk): Optimize this. pin_memory=True,
for req_idx in self._penalties_reqs: )
prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
max_prefill_len = int(prefill_lens.max())
bincount( bincount(
all_token_ids[req_idx], idx_mapping,
int(prefill_lens[req_idx]), self.req_states.all_token_ids.gpu,
int(prompt_lens[req_idx]), self.req_states.prompt_len.gpu,
self.prompt_bin_mask[req_idx], self.req_states.prefill_len.gpu,
self.output_bin_counts[req_idx], self.prompt_bin_mask,
self.output_bin_counts,
max_prefill_len,
) )
self._penalties_reqs.clear() self._new_penalties_reqs.clear()
self.repetition_penalty.copy_to_uva() self.repetition_penalty.copy_to_uva()
self.frequency_penalty.copy_to_uva() self.frequency_penalty.copy_to_uva()
...@@ -214,51 +223,82 @@ def apply_penalties( ...@@ -214,51 +223,82 @@ def apply_penalties(
) )
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"]) @triton.jit
def _bincount_kernel( def _bincount_kernel(
idx_mapping_ptr,
all_token_ids_ptr, all_token_ids_ptr,
prefill_len, all_token_ids_stride,
prompt_len, prompt_len_ptr,
prefill_len_ptr,
prompt_bin_mask_ptr, prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr, output_bin_counts_ptr,
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
block_idx = tl.program_id(0) batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len: if block_idx * BLOCK_SIZE >= prefill_len:
return return
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len: if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len mask = block < prompt_len
prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask) prompt_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
idx = prompt_tokens // 32 idx = prompt_tokens // 32
bit_idx = prompt_tokens % 32 bit_idx = prompt_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask) tl.atomic_or(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
bit,
mask=mask,
)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len: if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len mask = block < prefill_len
mask &= block >= prompt_len mask &= block >= prompt_len
output_tokens = tl.load(all_token_ids_ptr + block, mask=mask) output_tokens = tl.load(
tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask) all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
tl.atomic_add(
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ output_tokens,
1,
mask=mask,
)
def bincount( def bincount(
idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor, all_token_ids: torch.Tensor,
prefill_len: int, prompt_len: torch.Tensor,
prompt_len: int, prefill_len: torch.Tensor,
prompt_bin_mask: torch.Tensor, prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor, output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None: ) -> None:
prompt_bin_mask.zero_() prompt_bin_mask[idx_mapping] = 0
output_bin_counts.zero_() output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE) num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_blocks,)]( _bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
all_token_ids, all_token_ids,
prefill_len, all_token_ids.stride(0),
prompt_len, prompt_len,
prefill_len,
prompt_bin_mask, prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts, output_bin_counts,
output_bin_counts.stride(0),
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
) )
......
...@@ -15,6 +15,7 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs ...@@ -15,6 +15,7 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
from vllm.v1.worker.gpu.states import RequestState
class Sampler: class Sampler:
...@@ -23,9 +24,7 @@ class Sampler: ...@@ -23,9 +24,7 @@ class Sampler:
max_num_reqs: int, max_num_reqs: int,
vocab_size: int, vocab_size: int,
device: torch.device, device: torch.device,
all_token_ids: torch.Tensor, req_states: RequestState,
prompt_len: torch.Tensor,
total_len: torch.Tensor,
logprobs_mode: LogprobsMode = "raw_logprobs", logprobs_mode: LogprobsMode = "raw_logprobs",
num_speculative_tokens: int = 1, num_speculative_tokens: int = 1,
): ):
...@@ -35,9 +34,9 @@ class Sampler: ...@@ -35,9 +34,9 @@ class Sampler:
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default. self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
self.sampling_states = SamplingStates(max_num_reqs, vocab_size) self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device) self.penalties_state = PenaltiesState(req_states)
self.logit_bias_state = LogitBiasState(max_num_reqs, device) self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len) self.bad_words_state = BadWordsState(req_states)
self.num_speculative_tokens = num_speculative_tokens self.num_speculative_tokens = num_speculative_tokens
def add_request( def add_request(
...@@ -48,16 +47,9 @@ class Sampler: ...@@ -48,16 +47,9 @@ class Sampler:
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params) self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
self.bad_words_state.add_request(req_idx, sampling_params) self.bad_words_state.add_request(req_idx, sampling_params)
def apply_staged_writes( def apply_staged_writes(self) -> None:
self,
all_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
self.sampling_states.apply_staged_writes() self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes( self.penalties_state.apply_staged_writes()
all_token_ids, prefill_lens, prompt_lens
)
self.logit_bias_state.apply_staged_writes() self.logit_bias_state.apply_staged_writes()
self.bad_words_state.apply_staged_writes() self.bad_words_state.apply_staged_writes()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment