Unverified Commit e4a7b69f authored by Avelina9X, committed by GitHub

Added softmax_dtype argument to HFLM to coerce log_softmax computations (#2921)



* Added softmax_dtype argument to coerce log_softmax computations

* move softmax_dtype

---------
Co-authored-by: Baber <baber@hey.com>
parent 930d8378
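A minimal usage sketch of the new argument (not part of this commit; the model name and dtype values below are illustrative). Running the model weights in bfloat16 while coercing the log_softmax to float32 reduces rounding error in the loglikelihood scores:

from lm_eval.models.huggingface import HFLM

# Illustrative: keep the forward pass in bfloat16, but compute
# log-probabilities in float32 via the new softmax_dtype argument.
lm = HFLM(
    pretrained="gpt2",
    dtype="bfloat16",
    softmax_dtype="float32",
)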
@@ -74,6 +74,7 @@ class HFLM(TemplateLM):
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
+        softmax_dtype: Optional[Union[str, torch.dtype]] = None,
         batch_size: Optional[Union[int, str]] = 1,
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
@@ -234,6 +235,9 @@
         self.batch_schedule = 1
         self.batch_sizes = {}
         self.max_batch_size = max_batch_size
+        self.softmax_dtype = (
+            get_dtype(softmax_dtype) if softmax_dtype is not None else None
+        )
         if str(batch_size).startswith("auto"):
             batch_size = batch_size.split(":")
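The hunk above relies on get_dtype from lm_eval.models.utils, which resolves a dtype name to a torch.dtype. A simplified sketch of that contract (an approximation, not the actual implementation):

import torch

def get_dtype_sketch(dtype):
    # Pass torch.dtype values (and the string "auto") through unchanged;
    # map a name such as "float32" to the matching torch.float32.
    if isinstance(dtype, str) and dtype != "auto":
        return getattr(torch, dtype)
    return dtype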
@@ -768,7 +772,11 @@
                 (batch_size, max_length), device=self.device
             ).long()
             for _ in range(5):
-                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841
+                out = F.log_softmax(  # noqa: F841
+                    self._model_call(test_batch, **call_kwargs),
+                    dim=-1,
+                    dtype=self.softmax_dtype,
+                )
             return batch_size
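Note that F.log_softmax treats dtype=None as "keep the input dtype", so the default softmax_dtype=None preserves the previous behavior; a non-None dtype casts the logits before the reduction. A small self-contained check (illustrative shapes):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, dtype=torch.float16)

# dtype=None leaves the computation in the input precision...
assert F.log_softmax(logits, dim=-1, dtype=None).dtype == torch.float16
# ...while an explicit dtype casts the input first, which is what
# softmax_dtype coerces in the hunks above.
assert F.log_softmax(logits, dim=-1, dtype=torch.float32).dtype == torch.float32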
@@ -1200,7 +1208,9 @@
             }
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, **call_kwargs), dim=-1
+                self._model_call(batched_inps, **call_kwargs),
+                dim=-1,
+                dtype=self.softmax_dtype,
             )  # [batch, padding_length (inp or cont), vocab]
             for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
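For intuition on why coercion matters here: over a vocabulary-sized axis, a half-precision log_softmax can drift measurably from the float32 result. A quick comparison (illustrative, not from this PR):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 50000) * 10  # vocab-sized row, exaggerated scale

half = F.log_softmax(logits.half(), dim=-1).float()
full = F.log_softmax(logits, dim=-1)

# Worst-case deviation introduced by computing in float16;
# softmax_dtype="float32" removes this error at scoring time.
print((half - full).abs().max())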