"src/vscode:/vscode.git/clone" did not exist on "8f2253c58cf91e322615c0b7fbf2686bc61e71a0"
Unverified Commit 463c6632 authored by Liangsheng Yin, committed by GitHub

Eliminate 2 gpu ops during sampling when logit_bias is zero (#343)


Co-authored-by: Qubitium <417764+Qubitium@users.noreply.github.com>
parent b0890631
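
The change below is small but touches a hot path. Before it, every batch allocated a (bs, vocab_size) zero tensor for logit_bias and sample() unconditionally added it to the logits, so batches with no integer logit bias still paid for an allocation kernel and an add_ kernel. Keeping logit_bias as None until a request actually needs it removes both ops. A minimal, self-contained sketch of the idea (local variable names only, not the actual sglang code; the CPU fallback is just so the snippet runs anywhere):

import torch

# Device fallback is only so this sketch runs anywhere; sglang itself targets CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
bs, vocab_size = 4, 32000
logits = torch.randn(bs, vocab_size, device=device)
temperatures = torch.ones((bs, 1), device=device)

# Old path: allocate a zero bias for the whole batch and add it unconditionally.
# Numerically a no-op when nobody set a bias, but still two GPU kernel launches.
logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
logits_old = logits.clone()
logits_old.div_(temperatures)
logits_old.add_(logit_bias)

# New path: keep logit_bias as None until some request actually needs it,
# and skip the add entirely in that case.
logit_bias = None
logits_new = logits.clone()
logits_new.div_(temperatures)
if logit_bias is not None:
    logits_new.add_(logit_bias)

assert torch.equal(logits_old, logits_new)
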
@@ -251,10 +251,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias
 
         # Set fields
@@ -433,9 +437,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-            setattr(self, item, getattr(self, item)[new_indices])
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
 
-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
@@ -456,17 +463,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
-            )
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
 
         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
...
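
For completeness on the merge hunk above: zeros are only materialized when exactly one of the two batches carries a bias, so mixed batches can still be concatenated row by row, while a pair of bias-free batches stays at None and sample() keeps skipping the add. A small standalone sketch of that rule on CPU (the helper name and its arguments are local to this example, not part of sglang):

import torch

def merge_logit_bias(self_bias, other_bias, self_len, other_len, vocab_size):
    # Neither batch carries a bias: keep None so the add stays skipped.
    if self_bias is None and other_bias is None:
        return None
    # Exactly one side is missing: materialize zeros so the row-wise concat works.
    if self_bias is None:
        self_bias = torch.zeros((self_len, vocab_size), dtype=torch.float32)
    if other_bias is None:
        other_bias = torch.zeros((other_len, vocab_size), dtype=torch.float32)
    return torch.concat([self_bias, other_bias])

# A batch of 3 requests without a bias merged with a batch of 2 that has one.
merged = merge_logit_bias(None, torch.ones(2, 8), self_len=3, other_len=2, vocab_size=8)
assert merged.shape == (5, 8)
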