Commit a5b1c7a8 authored by Nathan Habib's avatar Nathan Habib
Browse files

fix dtype

parent bd6718e4
......@@ -292,6 +292,7 @@ class HFLM(TemplateLM):
self.batch_schedule = 1
self.batch_sizes = {}
self.max_batch_size = max_batch_size
self.dtype = get_dtype(dtype)
if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":")
......@@ -1124,11 +1125,10 @@ class HFLM(TemplateLM):
"labels": batched_conts,
}
logits = self._model_call(batched_inps, **call_kwargs)
multi_logits = F.log_softmax(
logits,
self._model_call(batched_inps, **call_kwargs),
dim=-1,
dtype=logits.dtype,
dtype=self.dtype,
) # [batch, padding_length (inp or cont), vocab]
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment