permute_ex.py 1.29 KB
Newer Older
Astha Rai's avatar
Astha Rai committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import jinja2

EXTRA_SHAPE_TEMPLATE = jinja2.Template(
    """
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
    ck::index_t M0 = M / G1 / G2;
    ck::index_t M1 = G1;
    ck::index_t M2 = G2;
    ck::index_t N0 = G3;
    ck::index_t N1 = N / G3;
    // GEMM shape
    //ck::index_t M = M0 * M1 * M2;
    //ck::index_t N = N0 * N1;
    //ck::index_t K = 128;
    //ck::index_t stride_A = K;
    //ck::index_t stride_B = K;
    // E = [M0, N0, M1, N1, M2]
    /* 0, 3, 1, 4, 2
    ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
    ck::index_t stride_E_M1 = N1 * M2;
    ck::index_t stride_E_M2 = 1;
    ck::index_t stride_E_N0 = M1 * N1 * M2;
    ck::index_t stride_E_N1 = M2;
    */
    // E = [M2, M0, N0, M1, N1] 2, 0, 3, 1, 4
    ck::index_t stride_E_M0 = N0* M1* N1;
    ck::index_t stride_E_M1 = N1;
    ck::index_t stride_E_M2 = M0* N0* M1* N1;
    ck::index_t stride_E_N0 = M1 * N1;
    ck::index_t stride_E_N1 = 1;
    // D = [0, N0, 0, N1, 0]
    ck::index_t stride_D_M0 = 0;
    ck::index_t stride_D_M1 = 0;
    ck::index_t stride_D_M2 = 0;
    ck::index_t stride_D_N0 = N1;
    ck::index_t stride_D_N1 = 1;
"""
)

output = EXTRA_SHAPE_TEMPLATE.render(indent=" ");
print (output)