Commit 2ff1f5a2 authored by Leo Gao's avatar Leo Gao
Browse files

Add imports and fix comment

parent 8ef7a515
......@@ -5,6 +5,8 @@ import numpy as np
import re
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
from lm_eval import utils
......@@ -130,7 +132,7 @@ class BaseLM(LM):
"""
pass
# subclass must implement properties vocab_size, eot_token_id, max_gen_toks.
# subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
# TODO: enforce this somehow
def loglikelihood(self, requests):
......@@ -174,9 +176,6 @@ class BaseLM(LM):
return loglikelihoods
# subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
# TODO: enforce this somehow
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment