Unverified Commit 2bcfba1b authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Skip unnecessary penalizer (#1707)

parent bc12d403
...@@ -515,11 +515,11 @@ class ScheduleBatch: ...@@ -515,11 +515,11 @@ class ScheduleBatch:
assert seq_len - pre_len == req.extend_input_len assert seq_len - pre_len == req.extend_input_len
if pre_len > 0: if pre_len > 0:
self.req_to_token_pool.req_to_token[req.req_pool_idx][ self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
:pre_len req.prefix_indices
] = req.prefix_indices )
self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len] out_cache_loc[pt : pt + req.extend_input_len]
) )
...@@ -535,10 +535,15 @@ class ScheduleBatch: ...@@ -535,10 +535,15 @@ class ScheduleBatch:
pt += req.extend_input_len pt += req.extend_input_len
# Set fields # Set fields
with out_cache_loc.device: self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) self.device, non_blocking=True
self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32) )
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.extend_num_tokens = extend_num_tokens self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
...@@ -782,8 +787,8 @@ class ScheduleBatch: ...@@ -782,8 +787,8 @@ class ScheduleBatch:
return return
self.reqs = [self.reqs[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor( new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
keep_indices, dtype=torch.int32, device=self.seq_lens.device self.device, non_blocking=True
) )
self.req_pool_indices = self.req_pool_indices[new_indices] self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices] self.seq_lens = self.seq_lens[new_indices]
......
...@@ -150,6 +150,7 @@ class Scheduler: ...@@ -150,6 +150,7 @@ class Scheduler:
nccl_port=port_args.nccl_port, nccl_port=port_args.nccl_port,
) )
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
self.device = self.tp_worker.device
# Get token and memory info from the model worker # Get token and memory info from the model worker
( (
...@@ -758,9 +759,7 @@ class Scheduler: ...@@ -758,9 +759,7 @@ class Scheduler:
if logits_output.next_token_logprobs is not None: if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = ( logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[ logits_output.next_token_logprobs[
torch.arange( torch.arange(len(next_token_ids), device=self.device),
len(next_token_ids), device=next_token_ids.device
),
next_token_ids, next_token_ids,
].tolist() ].tolist()
) )
...@@ -828,7 +827,7 @@ class Scheduler: ...@@ -828,7 +827,7 @@ class Scheduler:
# Move logprobs to cpu # Move logprobs to cpu
if batch.return_logprob: if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[ next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device), torch.arange(len(next_token_ids), device=self.device),
next_token_ids, next_token_ids,
].tolist() ].tolist()
......
...@@ -90,7 +90,7 @@ class BaseTokenToKVPool: ...@@ -90,7 +90,7 @@ class BaseTokenToKVPool:
select_index = self.free_slots[:need_size] select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:] self.free_slots = self.free_slots[need_size:]
return select_index.to(self.device) return select_index.to(self.device, non_blocking=True)
def free(self, free_index: torch.Tensor): def free(self, free_index: torch.Tensor):
if self.is_not_in_free_group: if self.is_not_in_free_group:
......
...@@ -135,25 +135,22 @@ class ForwardBatch: ...@@ -135,25 +135,22 @@ class ForwardBatch:
# Init position information # Init position information
if not ret.forward_mode.is_decode(): if not ret.forward_mode.is_decode():
ret.positions = torch.tensor( ret.positions = torch.concat(
np.concatenate( [
[ torch.arange(prefix_len, prefix_len + extend_len, device=device)
np.arange(prefix_len, prefix_len + extend_len) for prefix_len, extend_len in zip(
for prefix_len, extend_len in zip( batch.extend_prefix_lens, batch.extend_seq_lens
batch.extend_prefix_lens, batch.extend_seq_lens )
) ],
], axis=0,
axis=0,
),
dtype=torch.int64,
device=device,
) )
ret.image_inputs = batch.image_inputs ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device) ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor( ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, device=device batch.extend_prefix_lens, dtype=torch.int32
) ).to(device, non_blocking=True)
ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens) ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0) ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
ret.extend_seq_lens_cpu = batch.extend_seq_lens ret.extend_seq_lens_cpu = batch.extend_seq_lens
......
...@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator: ...@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
is_required = False
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
penalizer.prepare_if_required() pen_is_required = penalizer.prepare_if_required()
is_required |= pen_is_required
self.is_required = is_required
self.cumulate_input_tokens( if self.is_required:
input_ids=[req.origin_input_ids for req in self.reqs()] self.cumulate_input_tokens(
) input_ids=[req.origin_input_ids for req in self.reqs()]
)
def reqs(self): def reqs(self):
return self.batch.reqs return self.batch.reqs
...@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator: ...@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
Args: Args:
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
""" """
if not self.is_required:
return
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
...@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator: ...@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
Returns: Returns:
torch.Tensor: The logits after applying the penalizers. torch.Tensor: The logits after applying the penalizers.
""" """
if not self.is_required:
return
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
logits = penalizer.apply(logits) logits = penalizer.apply(logits)
...@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator: ...@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
indices_to_keep (typing.List[int]): List of indices to keep in the batch. indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
""" """
if not self.is_required:
return
empty_indices = len(indices_to_keep) == 0 empty_indices = len(indices_to_keep) == 0
is_required = False
for penalizer in self.penalizers.values(): for penalizer in self.penalizers.values():
if not penalizer.is_required() or empty_indices: tmp_is_required = penalizer.is_required()
is_required = is_required or tmp_is_required
if not tmp_is_required or empty_indices:
penalizer.teardown() penalizer.teardown()
else: else:
# create tensor index only when it's needed # create tensor index only when it's needed
...@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator: ...@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
indices_to_keep=indices_to_keep, indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep, indices_tensor_to_keep=indices_tensor_to_keep,
) )
self.is_required = is_required
def merge(self, their: "BatchedPenalizerOrchestrator"): def merge(self, their: "BatchedPenalizerOrchestrator"):
""" """
...@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator: ...@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
Args: Args:
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
""" """
if self.vocab_size != their.vocab_size: if not self.is_required and not their.is_required:
raise ValueError( return
f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
)
self.is_required |= their.is_required
for Penalizer, their_penalizer in their.penalizers.items(): for Penalizer, their_penalizer in their.penalizers.items():
if Penalizer not in self.penalizers: if Penalizer not in self.penalizers:
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
...@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC): ...@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
def prepare_if_required(self): def prepare_if_required(self):
if self.is_required(): if self.is_required():
self.prepare() self.prepare()
return True
else:
return False
def teardown(self): def teardown(self):
if self.is_prepared(): if self.is_prepared():
......
...@@ -48,20 +48,24 @@ class SamplingBatchInfo: ...@@ -48,20 +48,24 @@ class SamplingBatchInfo:
disable_penalizer: bool, disable_penalizer: bool,
): ):
reqs = batch.reqs reqs = batch.reqs
with batch.input_ids.device: device = batch.input_ids.device
temperatures = torch.tensor( temperatures = (
torch.tensor(
[r.sampling_params.temperature for r in reqs], [r.sampling_params.temperature for r in reqs],
dtype=torch.float, dtype=torch.float,
).view(-1, 1)
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
) )
.view(-1, 1)
.to(device, non_blocking=True)
)
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
).to(device, non_blocking=True)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
ret = cls( ret = cls(
temperatures=temperatures, temperatures=temperatures,
...@@ -80,7 +84,7 @@ class SamplingBatchInfo: ...@@ -80,7 +84,7 @@ class SamplingBatchInfo:
# #
# While we choose not to even create the class instances if they are not required, this # While we choose not to even create the class instances if they are not required, this
# could add additional complexity to the {ScheduleBatch} class, especially we need to # could add additional complexity to the {ScheduleBatch} class, especially we need to
# handle {filter_batch()} and {merge()} cases as well. # handle {filter_batch()} and {merge_batch()} cases as well.
if disable_penalizer: if disable_penalizer:
ret.penalizer_orchestrator = None ret.penalizer_orchestrator = None
else: else:
...@@ -112,19 +116,20 @@ class SamplingBatchInfo: ...@@ -112,19 +116,20 @@ class SamplingBatchInfo:
self.linear_penalties = None self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values(): for penalizer in self.penalizer_orchestrator.penalizers.values():
if not penalizer.is_prepared():
continue
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared(): self.scaling_penalties = penalizer.cumulated_repetition_penalties
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else: else:
if penalizer.is_prepared(): if self.linear_penalties is None:
if self.linear_penalties is None: bs = self.penalizer_orchestrator.batch.batch_size()
bs = self.penalizer_orchestrator.batch.batch_size() self.linear_penalties = torch.zeros(
self.linear_penalties = torch.zeros( (bs, self.vocab_size),
(bs, self.vocab_size), dtype=torch.float32,
dtype=torch.float32, device=self.device,
device=self.device, )
) self.linear_penalties = penalizer.apply(self.linear_penalties)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self): def update_regex_vocab_mask(self):
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms) has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
......
...@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase): ...@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
) )
actual = orchestrator.apply( original = torch.ones(
torch.ones( size=(len(case.test_subjects), self.vocab_size),
size=(len(case.test_subjects), self.vocab_size), dtype=torch.float32,
dtype=torch.float32, device=self.device,
device=self.device,
)
) )
actual = orchestrator.apply(original.clone())
expected = torch.cat( expected = torch.cat(
tensors=[ tensors=[
subject.steps[0].expected_logits subject.steps[0].expected_logits
for subject in case.test_subjects for subject in case.test_subjects
], ],
) )
if actual is None:
actual = original
torch.testing.assert_close( torch.testing.assert_close(
actual=actual, actual=actual,
expected=expected, expected=expected,
...@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase): ...@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
device=self.device, device=self.device,
) )
) )
if actual_logits is None:
continue
filtered_expected_logits = torch.cat( filtered_expected_logits = torch.cat(
tensors=[ tensors=[
subject.steps[0].expected_logits subject.steps[0].expected_logits
...@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase): ...@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}", msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
) )
actual_logits = orchestrator.apply( original = torch.ones(
torch.ones( size=(len(filtered_subjects), self.vocab_size),
size=(len(filtered_subjects), self.vocab_size), dtype=torch.float32,
dtype=torch.float32, device=self.device,
device=self.device,
)
) )
actual_logits = orchestrator.apply(original.clone())
filtered_expected_logits = torch.cat( filtered_expected_logits = torch.cat(
tensors=[ tensors=[
subject.steps[i].expected_logits subject.steps[i].expected_logits
for subject in filtered_subjects for subject in filtered_subjects
], ],
) )
if actual_logits is None:
actual_logits = original
torch.testing.assert_close( torch.testing.assert_close(
actual=actual_logits, actual=actual_logits,
expected=filtered_expected_logits, expected=filtered_expected_logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment