// traits.h
#pragma once

#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>

#include "config.h"

using namespace cute;

template<typename InputT_, bool Is_causal_>
struct Traits {
    using InputT = InputT_;
    static constexpr bool Is_causal = Is_causal_;
    static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
    static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
    static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
    static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;    
    static constexpr int NUM_THREADS = 256;
    static_assert(std::is_same_v<InputT, cutlass::float_e4m3_t>);
    static constexpr int kBlockM = BLOCK_SIZE_M;
    static constexpr int kBlockN = PAGE_BLOCK_SIZE;
    static constexpr int kHeadDim = HEAD_DIM_K;
    static constexpr int kHeadDimV = HEAD_DIM_V;
    static constexpr int kNWarps = 4;
    using Element = InputT;
    using elem_type = Element;
    using ElementAccum = float;

    using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>; 
    static constexpr int kSwizzle = 3;

zhanghj2's avatar
zhanghj2 committed
34
35
36
37
38
39
40
    using SmemLayoutAtomQ = 
        Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;

    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

zhanghj2's avatar
zhanghj2 committed
41
42
43
44
45
46
47
    using SmemLayoutAtomK = decltype(composition(
        Swizzle<kSwizzle, 4, 3>{},
        Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));

    using SmemLayoutK = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        Shape<Int<kBlockN>, Int<8 * 64>>{}));
zhanghj2's avatar
zhanghj2 committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
  
    using SmemLayoutAtomV = SmemLayoutAtomK;
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape<Int<kBlockN>, Int<kHeadDimV>>{}));

    using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape<Int<4*16*16>>{}));
    
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
    
    using SmemLayoutAtomK_place_holder = Layout<Shape<Int<kBlockN>, Int<64>>, Stride<_64, _1>>; 
    using SmemLayoutK_place_holder = decltype(tile_to_shape(
        SmemLayoutAtomK_place_holder{},
        Shape<Int<kBlockN>, Int<7*64>>{}));
zhanghj2's avatar
zhanghj2 committed
67

zhanghj2's avatar
zhanghj2 committed
68
69
    using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NN_LIT>;
    using MMA_Atom_Arch_16x32 = MMA_Atom<GFX938_16x32x32_F32F8F8F32E4M3E4M3_NT_LIT>;
zhanghj2's avatar
zhanghj2 committed
70

zhanghj2's avatar
zhanghj2 committed
71
72
73
74
75
76
77
78
79
80
81
82
    using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;

    using TiledMma = TiledMMA<
        MMA_Atom_Arch,
        Layout<Shape<_1, Int<kNWarps>, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;//
    using TiledMma_O = TiledMMA<
        MMA_Atom_Arch_16x32,
        Layout<Shape<_1, Int<kNWarps>, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;
    
    
zhanghj2's avatar
zhanghj2 committed
83
    struct SharedMemoryPlan {
zhanghj2's avatar
zhanghj2 committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;  // Double buffer

            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK_place_holder>> smem_temp;  // Double buffer
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
                cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
                cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
            };
       
        };

zhanghj2's avatar
zhanghj2 committed
101
102
103
104
105
106
    };

};