Commit 621214e1 authored by haileyschoelkopf's avatar haileyschoelkopf Committed by lintangsutawika
Browse files

fix issues with encoder_attns, test lambada

parent 1c409035
...@@ -226,7 +226,8 @@ class HFLM(LM): ...@@ -226,7 +226,8 @@ class HFLM(LM):
logits returned from the model's decoder logits returned from the model's decoder
""" """
with torch.no_grad(): with torch.no_grad():
if attn_mask or labels: if attn_mask is not None or labels is not None:
assert attn_mask is not None and labels is not None
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
return self.model( return self.model(
input_ids=inps, attention_mask=attn_mask, labels=labels input_ids=inps, attention_mask=attn_mask, labels=labels
...@@ -394,6 +395,10 @@ class HFLM(LM): ...@@ -394,6 +395,10 @@ class HFLM(LM):
device=self.device, device=self.device,
) )
(inplen,) = inp.shape (inplen,) = inp.shape
# build encoder attn masks
encoder_attns.append(torch.ones_like(inp))
cont = torch.tensor( cont = torch.tensor(
(continuation_enc)[-self.max_length :], (continuation_enc)[-self.max_length :],
# TODO: left-shift these? # TODO: left-shift these?
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment