test_gemm_autotune.py 4.97 KB
Newer Older
yuguo's avatar
yuguo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

import os, sys
import copy
import pytest
import tempfile
import shutil
import subprocess
import csv
import warnings

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

wenjh's avatar
wenjh committed
16
from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm
yuguo's avatar
yuguo committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143


def use_hipblaslt():
    """Report whether the hipBLASLt GEMM backend is expected to be in use.

    hipBLASLt is considered selected when ``NVTE_USE_HIPBLASLT`` is present
    in the environment, or when ``NVTE_USE_ROCBLAS`` is absent. Only the
    presence of the variables matters; their values are ignored.
    """
    hipblaslt_requested = "NVTE_USE_HIPBLASLT" in os.environ
    rocblas_requested = "NVTE_USE_ROCBLAS" in os.environ
    return hipblaslt_requested or not rocblas_requested


# Base file name of the algo cache CSV. test_gemm_autotune places it in a
# per-test temporary directory and points TE_HIPBLASLT_ALGO_LOAD /
# TE_HIPBLASLT_ALGO_SAVE at it.
storage_fname = "te_algo"


def dump_storage(fname):
    """Print the raw contents of the algo cache file *fname* for debugging.

    The output is framed by ``========`` separator lines. Lines read from the
    file keep their trailing newline, so it is stripped before printing;
    otherwise ``print`` would emit a blank line after every row.
    """
    print("========")
    with open(fname, "r") as ifile:
        for row in ifile:
            print(row.rstrip("\n"))
    print("========")


def analyse_storage(fname):
    """Validate the header of the algo cache CSV *fname* and return its columns.

    Args:
        fname: Path to the cache CSV file.

    Returns:
        The list of column names from the header row.

    Raises:
        AssertionError: if the file has no header row or lacks any of the
            required columns ``m``, ``algo_id``, ``ws_min``, ``ws_max``,
            ``aidx``.
    """
    # newline="" is what the csv module expects for files it parses.
    with open(fname, "r", newline="") as ifile:
        # Accessing ``fieldnames`` reads the header row on its own. Consuming
        # a data row is unnecessary and raised StopIteration on files that
        # contain only a header.
        head = csv.DictReader(ifile).fieldnames
    required = {"m", "algo_id", "ws_min", "ws_max", "aidx"}
    # ``head`` is None for a completely empty file.
    assert head is not None and required.issubset(head), "Invalid CSV format"
    return head

def read_storage(fname):
    """Load all data rows of the algo cache CSV *fname*.

    Returns:
        A list containing one dict per data row, keyed by the header's
        column names (values are strings, as produced by ``csv``).
    """
    with open(fname, "r") as ifile:
        return list(csv.DictReader(ifile))


def write_storage(fname, head, data):
    """Overwrite *fname* with a CSV made of the header *head* and rows *data*.

    Args:
        fname: Destination path; existing content is replaced.
        head: Ordered column names, written as the header row.
        data: Iterable of dicts mapping column names to cell values.
    """
    with open(fname, "w") as ofile:
        csv_writer = csv.DictWriter(ofile, fieldnames=head, lineterminator="\n")
        csv_writer.writeheader()
        for record in data:
            csv_writer.writerow(record)


@pytest.mark.skipif(not use_hipblaslt(), reason="Autotune requires hipBLASLt")
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="Autotune requires ROCm TE")
def test_gemm_autotune():
    """End-to-end check of hipBLASLt algo caching via TE_HIPBLASLT_ALGO_* vars.

    Each step re-runs this file in a child process (``--run``), which does a
    single GEMM, and then inspects the CSV cache the child loaded and saved:

    1. The initial run creates a cache with exactly one record.
    2. A fake record for a different ``m`` is carried over next to the real one.
    3. After lowering the cached ``ws_max`` (and raising ``ws_min`` when the
       range allows), the next run restores ``ws_max`` and keeps ``ws_min``.
    4. A wrong ``aidx`` is corrected back to the original index and algo.
    5. With an empty load path and ``TE_HIPBLASLT_ALGO_SELECTION`` one past
       the cached index, a different ``algo_id`` is chosen and cached.
    6. Loading that cache without the selection override keeps the new algo.
    """
    storage_dir = tempfile.mkdtemp()
    fname = os.path.join(storage_dir, storage_fname)
    ofile = fname + ".1"
    script = os.path.abspath(__file__)
    # Use the interpreter running pytest rather than whatever "python" is
    # first on PATH.
    run_args = [sys.executable, script, "--run"]
    # Private copy of the environment for the child processes. Mutating
    # os.environ would leak TE_HIPBLASLT_ALGO_* (pointing into a deleted temp
    # dir) into every later test in this pytest process.
    env = dict(os.environ)

    def run_child():
        # check=True: a crashing child fails the test right here instead of
        # later as a confusing missing-file or CSV-format error.
        subprocess.run(run_args, env=env, check=True)

    try:
        env["TE_HIPBLASLT_ALGO_LOAD"] = fname
        env["TE_HIPBLASLT_ALGO_SAVE"] = fname

        # Initial algo creation
        run_child()
        head = analyse_storage(fname)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        algo0 = copy.copy(algos[0])

        # From here on the child saves to ofile, so each step writes its
        # input cache to fname and reads the child's result from ofile.
        env["TE_HIPBLASLT_ALGO_SAVE"] = ofile

        # Unused cache entries: a fake record for a different shape
        algos[0]["m"] = "999" + algos[0]["m"]
        write_storage(fname, head, algos)
        run_child()
        algos = read_storage(ofile)
        assert len(algos) == 2, "Expected 2 cached records"
        assert algo0 == algos[1], "Invalid algo"

        # Adjust workspace size
        ws_max = int(algo0["ws_max"])
        if ws_max > 0:
            algos = [copy.copy(algo0)]
            # Decreasing the WS range should restore the size
            algos[0]["ws_max"] = str(ws_max - 1)
            ws_min = int(algos[0]["ws_min"])
            if ws_max - ws_min > 1:
                ws_min += 1
                algos[0]["ws_min"] = str(ws_min)
            write_storage(fname, head, algos)
            run_child()
            algos = read_storage(ofile)
            assert len(algos) == 1, "Expected 1 cached record"
            assert (str(ws_min), str(ws_max)) == (
                algos[0]["ws_min"],
                algos[0]["ws_max"],
            ), "Invalid WS size"
        else:
            warnings.warn("Cached algo Workspace size is 0")

        # Modify algo index
        algo_index = int(algo0["aidx"])
        algos = [copy.copy(algo0)]
        algos[0]["aidx"] = str(algo_index + 1)
        write_storage(fname, head, algos)
        run_child()
        algos = read_storage(ofile)
        assert len(algos) == 1, "Expected 1 cached record"
        assert (algo0["aidx"], algo0["algo_id"]) == (
            algos[0]["aidx"],
            algos[0]["algo_id"],
        ), "Invalid algo IDX"

        # Configure the autotune range so the current cached algo is outside
        # it, and cache the newly selected value
        env["TE_HIPBLASLT_ALGO_LOAD"] = ""
        env["TE_HIPBLASLT_ALGO_SAVE"] = fname
        env["TE_HIPBLASLT_ALGO_SELECTION"] = str(algo_index + 1)
        run_child()
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        algo1 = copy.copy(algos[0])
        assert algo0["algo_id"] != algo1["algo_id"], "Unexpected algo ID"

        # Restore the autotune range beginning; the new algo should still be used
        env["TE_HIPBLASLT_ALGO_LOAD"] = fname
        del env["TE_HIPBLASLT_ALGO_SELECTION"]
        run_child()
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        assert algo1 == algos[0], "Invalid algo ID"
    finally:
        shutil.rmtree(storage_dir)


def run_gemm():
    """Run one small fp16 GEMM through ``general_gemm``.

    This is the child-process entry used by ``test_gemm_autotune``
    (``python <this file> --run``). The GEMM runs in a fresh process with the
    TE_HIPBLASLT_ALGO_* variables the test prepared, presumably so the
    backend loads/saves its algo cache from them — confirm in TE.
    """
    N = 32
    datatype = torch.float16
    inp = torch.randn((N, N), device="cuda", dtype=datatype)
    # The result is never inspected. Not unpacking it avoids depending on how
    # many values general_gemm returns.
    general_gemm(A=inp, B=inp, dtype=datatype)
    # Wait for the device so an asynchronous kernel error is raised here and
    # the child exits non-zero instead of appearing to succeed.
    torch.cuda.synchronize()
yuguo's avatar
yuguo committed
145
146
147
148
149
150
151
152


if __name__ == "__main__":
    # Child-process entry used by test_gemm_autotune. Checking the argument
    # count avoids an IndexError when the script is run without arguments.
    if len(sys.argv) > 1 and sys.argv[1] == "--run":
        run_gemm()