"tests/vscode:/vscode.git/clone" did not exist on "bce51cbd65b308360e5c83ed0f773100ec9fda79"
Unverified Commit ed27a6b9 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Revert "Eliminate 2 gpu ops during sampling when logit_bias is zero" (#345)

parent 463c6632
...@@ -251,14 +251,10 @@ class Batch: ...@@ -251,14 +251,10 @@ class Batch:
] = out_cache_loc[pt : pt + extend_lens[i]] ] = out_cache_loc[pt : pt + extend_lens[i]]
pt += extend_lens[i] pt += extend_lens[i]
# Handle logit bias but only allocate when needed # Handle logit bias
logit_bias = None logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
for i in range(bs): for i in range(bs):
if reqs[i].sampling_params.dtype == "int": if reqs[i].sampling_params.dtype == "int":
if logit_bias is None:
logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
logit_bias[i] = int_token_logit_bias logit_bias[i] = int_token_logit_bias
# Set fields # Set fields
...@@ -437,12 +433,9 @@ class Batch: ...@@ -437,12 +433,9 @@ class Batch:
"presence_penalties", "presence_penalties",
"logit_bias", "logit_bias",
]: ]:
self_val = getattr(self, item, None) setattr(self, item, getattr(self, item)[new_indices])
# logit_bias can be None
if self_val is not None:
setattr(self, item, self_val[new_indices])
def merge(self, other: "Batch"): def merge(self, other):
self.reqs.extend(other.reqs) self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat( self.req_pool_indices = torch.concat(
...@@ -463,34 +456,17 @@ class Batch: ...@@ -463,34 +456,17 @@ class Batch:
"top_ks", "top_ks",
"frequency_penalties", "frequency_penalties",
"presence_penalties", "presence_penalties",
"logit_bias",
]: ]:
self_val = getattr(self, item, None) setattr(
other_val = getattr(other, item, None) self, item, torch.concat([getattr(self, item), getattr(other, item)])
setattr(self, item, torch.concat([self_val, other_val]))
# logit_bias can be None
if self.logit_bias is not None or other.logit_bias is not None:
vocab_size = (
self.logit_bias.shape[1]
if self.logit_bias is not None
else other.logit_bias.shape[1]
) )
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
if other.logit_bias is None:
other.logit_bias = torch.zeros(
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
def sample(self, logits: torch.Tensor): def sample(self, logits: torch.Tensor):
# Post process logits # Post process logits
logits = logits.contiguous() logits = logits.contiguous()
logits.div_(self.temperatures) logits.div_(self.temperatures)
if self.logit_bias is not None: logits.add_(self.logit_bias)
logits.add_(self.logit_bias)
has_regex = any(req.regex_fsm is not None for req in self.reqs) has_regex = any(req.regex_fsm is not None for req in self.reqs)
if has_regex: if has_regex:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment