# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for cublaslt-gemm benchmark."""

import unittest
7
from types import GeneratorType, SimpleNamespace
8
9
10
11

from tests.helper.testcase import BenchmarkTestCase
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
from superbench.benchmarks.result import BenchmarkResult
12
from superbench.benchmarks.micro_benchmarks.blaslt_function_base import mrange, validate_mrange
13
14
15
16
17
18
19
20
21
22


class CublasLtBenchmarkTestCase(BenchmarkTestCase, unittest.TestCase):
    """Class for cublaslt-gemm benchmark test cases."""
    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in the class."""
        super().setUpClass()
        cls.benchmark_name = 'cublaslt-gemm'
        cls.createMockEnvs(cls)
        # Placeholder for the benchmark binary; presumably needed so _preprocess() can resolve
        # the binary path without a real CUDA build — confirm against BenchmarkTestCase helpers.
        cls.createMockFiles(cls, ['bin/cublaslt_gemm'])

    def get_benchmark(self, parameters=''):
        """Get a cublaslt-gemm benchmark instance registered for the CUDA platform.

        Args:
            parameters (str): Command-line style parameters passed to the benchmark constructor.
                Defaults to '' (no parameters).

        Returns:
            A new benchmark instance built from the CUDA-registered benchmark class.
        """
        (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
        return benchmark_cls(self.benchmark_name, parameters=parameters)

    def test_cublaslt_gemm_cls(self):
        """Test cublaslt-gemm benchmark class is registered for CUDA only."""
        for platform in Platform:
            (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, platform)
            if platform is Platform.CUDA:
                self.assertIsNotNone(benchmark_cls)
            else:
                self.assertIsNone(benchmark_cls)

    def test_mrange(self):
        """Test mrange generation."""
        self.assertIsInstance(mrange(1), GeneratorType)
        # Default / 'x' mode multiplies by the step; non-positive or too-small steps yield only the start.
        self.assertListEqual([4, 8, 16, 32], list(mrange(4, 32, 2)))
        self.assertListEqual([2, 4, 8, 16], list(mrange(2, 31, 2)))
        self.assertListEqual([2, 4, 8], list(mrange(2, 8)))
        self.assertListEqual([2], list(mrange(2, 0, 2)))
        self.assertListEqual([2], list(mrange(2)))
        self.assertListEqual([2], list(mrange(2, 4, 1)))
        self.assertListEqual([2], list(mrange(2, 4, 0)))
        self.assertListEqual([0], list(mrange(0, 0)))
        self.assertListEqual([0], list(mrange(0)))
        self.assertListEqual([4, 8, 16, 32], list(mrange(4, 32, 2, 'x')))
        # '+' mode adds the step instead of multiplying.
        self.assertListEqual([4, 8, 12, 16, 20, 24, 28, 32], list(mrange(4, 32, 4, '+')))

    def test_validate_mrange(self):
        """Test mrange validation."""
        self.assertTrue(validate_mrange('2:32:2'))
        self.assertTrue(validate_mrange('4:32'))
        self.assertTrue(validate_mrange('8'))
        self.assertFalse(validate_mrange('2:32:2:4'))
        self.assertFalse(validate_mrange('2.5:32'))
        self.assertFalse(validate_mrange('2:32:2:x4'))
        self.assertFalse(validate_mrange('2:32:2:+4'))

    def test_cublaslt_gemm_command_generation(self):
        """Test cublaslt-gemm benchmark command generation."""
        benchmark = self.get_benchmark(
            parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64 int8',
        )
        self.assertTrue(benchmark._preprocess())

        in_types = ['fp16', 'fp32', 'fp64', 'int8']
        # '--batch 2:16:2' expands (multiplicatively) to these batch sizes.
        batches = [2, 4, 8, 16]
        # '2:4,4:8,8:32' expands to m in [2, 4], n in [4, 8], k in [8, 16, 32] -> 2 * 2 * 3 shapes;
        # '32:128:4,128,128' expands to m in [32, 128] with fixed n, k -> 2 shapes.
        num_shapes = 2 * 2 * 3 + 2
        self.assertEqual(len(batches) * num_shapes * len(in_types), len(benchmark._commands))

        def cmd(t, b, m, n, k):
            return f'{benchmark._CublasLtBenchmark__bin_path} -m {m} -n {n} -k {k} -b {b} -w 20 -i 50 -t {t}'

        for _t in in_types:
            for _b in batches:
                for _m in [2, 4]:
                    for _n in [4, 8]:
                        for _k in [8, 16, 32]:
                            self.assertIn(cmd(_t, _b, _m, _n, _k), benchmark._commands)
                for _m in [32, 128]:
                    self.assertIn(cmd(_t, _b, _m, 128, 128), benchmark._commands)

    def test_cublaslt_gemm_result_parsing(self):
        """Test cublaslt-gemm benchmark result parsing."""
        benchmark = self.get_benchmark()
        self.assertTrue(benchmark._preprocess())
        # Override the parsed arguments so result keys are derived from known shapes / types.
        benchmark._args = SimpleNamespace(shapes=['16,16,16', '32,64,128'], in_types=['fp8e4m3'], log_raw_data=False)
        benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)

        # Positive case - valid raw output
        self.assertTrue(benchmark._process_raw_result(0, '16   16    16    0       1.111      2.222'))
        self.assertTrue(benchmark._process_raw_result(1, '32   64    128    0       1.111      2.222'))
        self.assertEqual(ReturnCode.SUCCESS, benchmark.return_code)

        # One flops metric per shape, plus the return-code entry.
        self.assertEqual(3, len(benchmark.result))
        for shape in benchmark._args.shapes:
            self.assertAlmostEqual(2.222, benchmark.result[f'fp8e4m3_0_{shape.replace(",", "_")}_flops'][0])

        # Negative case - invalid raw output
        self.assertFalse(benchmark._process_raw_result(1, 'cuBLAS API failed'))