test_prints.py 2.31 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Union

import pytest

from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps import tasks as tasks_
from tests.unit.conftest import DummyLLM

# Task classes (or partially-applied task constructors) exercised by `test_prints` below.
# Entries wrapped in `functools.partial` pre-bind keyword arguments so that every entry
# can be called uniformly as `task(llm=llm)`; presumably those are the arguments the
# task needs in order to be instantiated -- confirm against each task's signature.
# Tasks that are not listed here either don't have a `print` method or have one that
# doesn't work.
tasks = [
    tasks_.ComplexityScorer,
    partial(tasks_.EvolInstruct, num_evolutions=1),
    partial(tasks_.EvolComplexity, num_evolutions=1),
    partial(tasks_.EvolComplexityGenerator, num_instructions=1),
    partial(tasks_.EvolInstructGenerator, num_instructions=1),
    partial(tasks_.EvolQuality, num_evolutions=1),
    tasks_.Genstruct,
    partial(
        tasks_.BitextRetrievalGenerator,
        source_language="English",
        target_language="Spanish",
        unit="sentence",
        difficulty="elementary school",
        high_score="4",
        low_score="2.5",
    ),
    partial(tasks_.EmbeddingTaskGenerator, category="text-retrieval"),
    tasks_.GenerateLongTextMatchingData,
    tasks_.GenerateShortTextMatchingData,
    tasks_.GenerateTextClassificationData,
    tasks_.GenerateTextRetrievalData,
    tasks_.MonolingualTripletGenerator,
    tasks_.InstructionBacktranslation,
    tasks_.Magpie,
    tasks_.MagpieGenerator,
    partial(tasks_.PrometheusEval, mode="absolute", rubric="factual-validity"),
    tasks_.QualityScorer,
    tasks_.SelfInstruct,
    partial(tasks_.GenerateSentencePair, action="paraphrase"),
    tasks_.UltraFeedback,
    tasks_.URIAL,
]


class TestLLM(DummyLLM, MagpieChatTemplateMixin):
    """Dummy LLM combined with the Magpie chat-template mixin for the print tests.

    The pre-query template defaults to ``"llama3"``. Presumably the mixin and this
    default are needed so that the Magpie-based tasks in ``tasks`` can be built
    with this LLM -- confirm against ``MagpieChatTemplateMixin``.

    NOTE(review): the ``Test`` prefix matches pytest's default class-collection
    pattern, so pytest may try to collect this helper class -- confirm whether a
    collection warning is emitted and whether a rename is worthwhile.
    """

    magpie_pre_query_template: Union[str, None] = "llama3"


# Single module-level LLM instance passed to every parametrized `test_prints` case.
llm = TestLLM()


@pytest.mark.parametrize("task", tasks)
def test_prints(task) -> None:
    """Check that a task can be loaded, printed and unloaded without raising.

    Args:
        task: A task class or ``functools.partial`` taken from ``tasks``; it is
            called with the shared module-level ``llm`` to build the instance.
    """
    t = task(llm=llm)
    t.load()
    try:
        # The test passes as long as `print` does not raise; its output is not
        # inspected.
        t.print()
    finally:
        # Always unload, so a failing `print` does not leave the task loaded
        # for the remaining parametrized cases.
        t.unload()