Commit e6decf02 authored by Leo Gao

gpt2: extract model call

parent 7a39d68b
@@ -155,7 +155,7 @@ class GPT2LM(LM):
                 contlens.append(cont)
                 inplens.append(inplen)
-            multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1).cpu()  # [batch, seq, vocab]
+            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]
             for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
                 contlen = len(cont_toks)
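
For context, the loop this hunk touches scores each continuation by reading its token log-probabilities back out of `multi_logits`. A minimal standalone sketch of that pattern follows; the function name and arguments are illustrative, and it assumes the model input was the full sequence with its last token dropped, so position `i` of the logits predicts token `i + 1`:

```python
import torch
import torch.nn.functional as F

def score_continuation(logits, cont_toks, inplen):
    """Sum the log-probabilities a model assigns to one continuation.

    logits:    [seq, vocab] raw logits for a single input sequence
    cont_toks: token ids of the continuation being scored
    inplen:    length of the tensor that was fed to the model
    """
    logprobs = F.log_softmax(logits, dim=-1)              # [seq, vocab]
    contlen = len(cont_toks)
    # The last `contlen` input positions are the ones whose logits
    # predict the continuation tokens.
    cont_logprobs = logprobs[inplen - contlen : inplen]   # [contlen, vocab]
    targets = torch.tensor(cont_toks).unsqueeze(-1)       # [contlen, 1]
    return cont_logprobs.gather(1, targets).sum().item()
```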
@@ -183,6 +183,16 @@ class GPT2LM(LM):
 
         return reord.get_original(res)
 
+    def _model_call(self, inps):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        return self.gpt2(inps)[0][:, :, :50257]
+
     def greedy_until(self, requests):
         # TODO: implement fully general `until` that handles untils that are
         # multiple tokens or that span multiple tokens correctly
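The payoff of extracting `_model_call` is that a subclass can swap in a different forward pass without touching the batching and scoring logic above. A hedged sketch of such an override; the subclass name and `device` attribute are assumptions, not part of this commit:

```python
import torch

class GPT2LMOnDevice(GPT2LM):
    """Hypothetical subclass that runs the forward pass on a chosen device."""

    def _model_call(self, inps):
        # Same contract as the base class: [batch, sequence] in,
        # [batch, sequence, vocab] out. The [:, :, :50257] slice keeps
        # only GPT-2's real vocabulary entries.
        with torch.no_grad():
            return self.gpt2(inps.to(self.device))[0][:, :, :50257]
```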