test_models.py 370 Bytes
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import lm_eval.models as models
import lm_eval.base as base

def test_gpt2():
    """Smoke-test the GPT-2 model wrapper's ``loglikelihood`` scoring on CPU.

    Two requests share the same context and differ only in the continuation
    (' dog' vs ' cat'). The test checks two things:

    * the first score of the ' dog' result is strictly greater than that of
      the ' cat' result;
    * the second element of the ' cat' result is falsy.

    NOTE(review): each result is presumably a ``(log-likelihood, is_greedy)``
    pair; the original local names (``ll_*`` / ``ig_*``) suggest this, but the
    return format is defined in ``lm_eval.models`` and is not visible here.
    """
    model = models.get_model('gpt2')(device="cpu")

    context = 'The quick brown fox jumps over the lazy'
    results = model.loglikelihood([
        (context, ' dog'),
        (context, ' cat'),
    ])

    # Unpacking enforces exactly two results, each a 2-element sequence.
    # The greedy flag for ' dog' is intentionally not asserted.
    (dog_score, _dog_greedy), (cat_score, cat_greedy) = results

    assert dog_score > cat_score
    assert not cat_greedy