# test_models.py (625 bytes) -- last committed by Leo Gao
import lm_eval.models as models
import lm_eval.base as base

def test_gpt2():
    """Check loglikelihood and greedy_until of the 'gpt2' model on a fixed prompt.

    The model is built via ``models.get_model('gpt2')`` with the argument
    string ``"device=cpu"``. Three things are exercised:

    * ``loglikelihood`` must score the continuation ' dog' strictly higher
      than ' cat' for the same prompt, and the second element returned for
      ' cat' must be falsy (presumably an "is greedy" flag -- confirm
      against the model API).
    * ``loglikelihood`` is called with an empty context string.
    * ``greedy_until`` must return exactly one generation equal to a
      pinned reference string.
    """
    lm = models.get_model('gpt2').create_from_arg_string("device=cpu")

    prompt = 'The quick brown fox jumps over the lazy'

    scored = lm.loglikelihood([
        (prompt, ' dog'),
        (prompt, ' cat'),
    ])
    # Unpacking doubles as a shape check: two results, each a 2-tuple.
    (ll_dog, _ig_dog), (ll_cat, ig_cat) = scored

    assert ll_dog > ll_cat
    assert not ig_cat

    # test empty context (the result itself is discarded)
    lm.loglikelihood([('', 'test')])

    # Second tuple element is presumably the list of stop sequences --
    # confirm against greedy_until's contract.
    generations = lm.greedy_until([
        (prompt, ['.', '\n'])
    ])
    # Single-target unpacking fails unless exactly one generation came back.
    [gen] = generations

    assert gen == ', lazy fox and they both fall to the ground'