Unverified commit 463c6632, authored by Liangsheng Yin, committed by GitHub
Browse files

Eliminate 2 gpu ops during sampling when logit_bias is zero (#343)


Co-authored-by: Qubitium <417764+Qubitium@users.noreply.github.com>
parent b0890631
......@@ -251,10 +251,14 @@ class Batch:
] = out_cache_loc[pt : pt + extend_lens[i]]
pt += extend_lens[i]
# Handle logit bias
logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
# Handle logit bias but only allocate when needed
logit_bias = None
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
if logit_bias is None:
logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
logit_bias[i] = int_token_logit_bias
# Set fields
......@@ -433,9 +437,12 @@ class Batch:
"presence_penalties",
"logit_bias",
]:
setattr(self, item, getattr(self, item)[new_indices])
self_val = getattr(self, item, None)
# logit_bias can be None
if self_val is not None:
setattr(self, item, self_val[new_indices])
def merge(self, other):
def merge(self, other: "Batch"):
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
......@@ -456,17 +463,34 @@ class Batch:
"top_ks",
"frequency_penalties",
"presence_penalties",
"logit_bias",
]:
setattr(
self, item, torch.concat([getattr(self, item), getattr(other, item)])
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
# logit_bias can be None
if self.logit_bias is not None or other.logit_bias is not None:
vocab_size = (
self.logit_bias.shape[1]
if self.logit_bias is not None
else other.logit_bias.shape[1]
)
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
if other.logit_bias is None:
other.logit_bias = torch.zeros(
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
def sample(self, logits: torch.Tensor):
# Post process logits
logits = logits.contiguous()
logits.div_(self.temperatures)
logits.add_(self.logit_bias)
if self.logit_bias is not None:
logits.add_(self.logit_bias)
has_regex = any(req.regex_fsm is not None for req in self.reqs)
if has_regex:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment