traits.h 5.41 KB
Newer Older
zhanghj2's avatar
zhanghj2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#pragma once

#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>

#include "config.h"

using namespace cute;

/// Compile-time configuration for the attention kernel built on this header.
///
/// Bundles tile sizes, shared-memory layouts, tiled MMA / tiled copy types and
/// the shared-memory plan. Most sizes come from `Config` (declared in config.h).
///
/// @tparam InputT_    Element type of the inputs. Only cutlass::bfloat16_t is
///                    accepted (enforced by the static_assert below), so every
///                    `std::conditional_t<is_same<elem_type, half_t>, ...>`
///                    below currently resolves to its bf16 branch.
/// @tparam Is_causal_ Compile-time causal flag, re-exported as `Is_causal`;
///                    it is not consumed inside this struct itself.
template<typename InputT_, bool Is_causal_>
struct Traits {
    using InputT = InputT_;
    static constexpr bool Is_causal = Is_causal_;

    // Problem / tile sizes taken from config.h.
    static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
    static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
    static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
    static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
    static constexpr int NUM_THREADS = 256;
    static_assert(std::is_same_v<InputT, cutlass::bfloat16_t>);

    // Short aliases of the sizes above.
    static constexpr int kBlockM = BLOCK_SIZE_M;
    static constexpr int kBlockN = PAGE_BLOCK_SIZE;
    static constexpr int kHeadDim = HEAD_DIM_K;
    static constexpr int kHeadDimV = HEAD_DIM_V;
    // Warp count of every tiled MMA below (1 x kNWarps x 1 warp grid).
    // With NUM_THREADS == 256 this means 64 threads per warp -- presumably the
    // GFX928 wavefront size; TODO confirm.
    static constexpr int kNWarps = 4;
    static_assert(NUM_THREADS % kNWarps == 0,
                  "NUM_THREADS must split evenly across kNWarps warps");
    static constexpr int kSwizzle = 3;

    using Element = InputT;
    using elem_type = Element;
    using ElementAccum = float;
    // No value-level repetition in any of the tiled MMAs.
    using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;

    // 128-entry contiguous float vector backing smem_row_sum / smem_row_max.
    // NOTE(review): hard-coded 128 rather than kBlockM -- confirm it is meant
    // as an upper bound on the rows handled per block.
    using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;

    // K tile: kBlockN rows x 256 (= 8 * 32) columns, tiled from swizzled
    // row-major 8x32 atoms.
    using SmemLayoutAtomK = decltype(composition(
        Swizzle<3, 3, 3>{},
        Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
    using SmemLayoutK = decltype(tile_to_shape(
        SmemLayoutAtomK{},
        Shape<Int<kBlockN>, Int<8 * 32>>{}));

    // V tile: same atom and same kBlockN x 256 shape as K.
    using SmemLayoutAtomV = SmemLayoutAtomK;
    using SmemLayoutV = decltype(tile_to_shape(
        SmemLayoutAtomV{},
        Shape<Int<kBlockN>, Int<8 * 32>>{}));

    // P buffer: one flat, contiguous run of 4 * 16 * 16 = 1024 elements.
    using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
    using SmemLayoutP = decltype(tile_to_shape(
        SmemLayoutAtomP{},
        Shape<Int<4*16*16>>{}));

    // Transposed (256 x kBlockN) view over the same V storage, plus the
    // swizzle-free portion of that view.
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutV{}, make_layout(Shape<Int<8 * 32>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    // "_fp8" V layouts: a single unswizzled row-major kBlockN x 512 tile and
    // its transposed (512 x kBlockN) view. Not referenced inside this struct.
    using SmemLayoutAtomV_fp8 = Layout<Shape<Int<kBlockN>, Int<512>>, Stride<_512, _1>>;
    using SmemLayoutV_fp8 = decltype(tile_to_shape(
        SmemLayoutAtomV_fp8{},
        Shape<Int<kBlockN>, Int<512>>{}));
    using SmemLayoutVtransposed_fp8 = decltype(
        composition(SmemLayoutV_fp8{}, make_layout(Shape<Int<512>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle_fp8 = decltype(get_nonswizzle_portion(SmemLayoutVtransposed_fp8{}));

    // Q tile: kBlockM x kHeadDim, tiled from swizzled row-major 8x64 atoms.
    using SmemLayoutAtomQ = decltype(composition(
        Swizzle<kSwizzle, 3, 3>{},
        Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    // ---- Tiled MMAs: each spreads its atom over a 1 x kNWarps x 1 warp grid.

    using MMA_Atom_Arch_16_16_32 = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
        MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
    >;
    // Uses kNWarps (formerly a hard-coded Int<4>) so it stays in step with the
    // other tiled MMAs if the warp count changes.
    using TiledMma_16_16_32 = TiledMMA<
        MMA_Atom_Arch_16_16_32,
        Layout<Shape<_1, Int<kNWarps>, _1>>,
        ValLayoutMNK>;

    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<GFX928_16x16x64_F32F16F16F32_NT>,
        MMA_Atom<GFX928_16x16x64_F32BF16BF16F32_NT>
    >;

    using MMA_Atom_Arch_16_32_16 = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
        MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
    >;

    using TiledMma_O_16_32_16 = TiledMMA<
        MMA_Atom_Arch_16_32_16,
        Layout<Shape<_1, Int<kNWarps>, _1>>,
        ValLayoutMNK>;

    using TiledMma = TiledMMA<
        MMA_Atom_Arch,
        Layout<Shape<_1, Int<kNWarps>, _1>>,
        ValLayoutMNK>;

    // NOTE(review): the half_t branch uses a *uint8* atom while the bf16 branch
    // uses *int8*. The half_t branch is unreachable while the static_assert on
    // InputT holds, but confirm the signedness difference is intended.
    using MMA_Atom_Arch_int8 = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<GFX928_16x16x64_F32F16uint8F32_NT>,
        MMA_Atom<GFX928_16x16x64_F32BF16int8F32_NT>
    >;
    using MMA_Atom_Arch_16x64 = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<GFX928_16x64x16_FP8_F32F16F16F32_NT>,
        MMA_Atom<GFX928_16x64x16_FP8_F32BF16BF16F32_NT>
    >;

    using TiledMma_O = TiledMMA<
        MMA_Atom_Arch_16x64,
        Layout<Shape<_1, Int<kNWarps>, _1>>,
        ValLayoutMNK>;

    using TiledMma_int8 = TiledMMA<
        MMA_Atom_Arch_int8,
        Layout<Shape<_1, Int<kNWarps>, _1>>,
        ValLayoutMNK>;

    // ---- Tiled copy for Q: a row-major 32 x 8 thread grid in which each
    // thread moves 8 contiguous elements along a row, i.e. one 32 x 64
    // element tile per copy step.
    using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
        Stride< _8, _1>>;
    // The thread grid above is hard-coded; make sure it covers exactly one
    // slot per thread of the block.
    static_assert(decltype(cute::size(GmemLayoutAtomQ{}))::value == NUM_THREADS,
                  "GmemLayoutAtomQ must have exactly NUM_THREADS entries");
    using GmemTiledCopyQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
            GmemLayoutAtomQ{},
            Layout<Shape<_1, _8>>{}));

    // Shared-memory plan. The two anonymous structs overlay each other inside
    // a union, so smem_q shares storage with smem_v / smem_p / the row
    // statistics and must not be live at the same time as them; the plan's
    // size is that of the larger group.
    // NOTE(review): anonymous structs inside a union are a compiler extension
    // (accepted by GCC/Clang), not ISO C++.
    struct SharedMemoryPlan {
        union {
            struct {
                // V tile laid out by SmemLayoutV (kBlockN x 256).
                // NOTE(review): this was previously annotated "Double buffer",
                // but SmemLayoutV has no stage mode, so it holds one tile.
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
                // Per-row float statistics (names suggest softmax row sum and
                // row max -- confirm at the use sites).
                cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
                cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
            };
        };
    };
};