test_remote_load.py 1.57 KB
Newer Older
zzg_666's avatar
zzg_666 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from dataflow.operators.reasoning import (
    ReasoningAnswerNgramFilter
)

from dataflow.utils.storage import FileStorage
import pytest

class RemoteDataLoader:
    """Pipeline that loads GSM8K via two remote-source specs and filters it.

    Two ``FileStorage`` instances are built: one whose first entry is the
    Hugging Face spec ``"hf:openai/gsm8k:main:train"`` and one whose first
    entry is the ModelScope spec ``"ms:modelscope/gsm8k:train"``. A single
    ``ReasoningAnswerNgramFilter`` is then run over each storage in turn.
    """

    def __init__(self, cache_path: str = "./cache"):
        """Create both storages and the n-gram filter.

        Args:
            cache_path: Directory handed to ``FileStorage`` as ``cache_path``
                for both storages. Defaults to ``"./cache"`` (resolved against
                the current working directory), the previously hard-coded
                value; pass e.g. a pytest ``tmp_path`` to keep runs isolated.
        """
        # Each storage gets its own file_name_prefix; presumably this keeps
        # the HF and ModelScope cache files apart inside the shared
        # cache_path (FileStorage's naming scheme is not visible here).
        self.storage_1 = FileStorage(
            first_entry_file_name="hf:openai/gsm8k:main:train",
            cache_path=cache_path,
            file_name_prefix="dataflow_cache_step_hf",
            cache_type="jsonl",
        )
        self.storage_2 = FileStorage(
            first_entry_file_name="ms:modelscope/gsm8k:train",
            cache_path=cache_path,
            file_name_prefix="dataflow_cache_step_ms",
            cache_type="jsonl",
        )

        # One filter instance is reused for both storages in forward().
        self.answer_ngram_filter_step1 = ReasoningAnswerNgramFilter(
            min_score=0.1,
            max_score=1.0,
            ngrams=5,
        )

    def forward(self):
        """Run the n-gram filter over storage_1 (HF), then storage_2 (ModelScope).

        Each storage's ``step()`` is called immediately before its filter run,
        in the same order as before. The filter reads the ``"question"`` and
        ``"answer"`` keys.
        """
        for storage in (self.storage_1, self.storage_2):
            self.answer_ngram_filter_step1.run(
                storage=storage.step(),
                input_question_key="question",
                input_answer_key="answer",
            )
@pytest.mark.gpu
def test_remote_data_loader():
    """Smoke test: build RemoteDataLoader and run its full pipeline once.

    Any exception raised during construction or ``forward()`` propagates, so
    pytest reports it as a failure with the original exception type and full
    traceback. The previous ``try/except Exception: pytest.fail(f"... {e}")``
    wrapper discarded both.

    NOTE(review): marked ``gpu``; presumably this also needs network access
    to the Hugging Face / ModelScope hubs. Confirm the marker choice.
    """
    loader = RemoteDataLoader()
    loader.forward()

if __name__ == "__main__":
    # Script entry point: build the pipeline and run it once, same as the test.
    RemoteDataLoader().forward()