mha_common.h 5.27 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#pragma once
// SPDX-License-Identifier: MIT
 

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

// Argument-validation helpers. Each expands to a TORCH_CHECK whose failure message
// names the offending expression (via #x). The argument is parenthesized so that
// compound expressions such as `*ptr` bind correctly to the member call.
#define CHECK_DEVICE(x) TORCH_CHECK((x).is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

namespace aiter {
__global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state);

inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

inline int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
        return num_splits;

    hipDeviceProp_t props{};
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
        return num_splits;

    // TODO - tile size should match the TileFmhaShape, hardcode for now
    const int kM0 = 128;
    const int kN1 = hdim_v;

    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;

    if(num_splits < 1 && p_drop == 0.0f)
        return num_splits_heuristic_ck(
            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);

    return num_splits;
}

template<typename ARG>
inline void print_fmha_fwd_args(ARG args)
{
    printf("seqlen_q = %d\n", args.seqlen_q);
    printf("seqlen_k = %d\n", args.seqlen_k);
    printf("batch = %d\n", args.batch);
    printf("max_seqlen_q = %d\n", args.max_seqlen_q);
    printf("hdim_q = %d\n", args.hdim_q);
    printf("hdim_v = %d\n", args.hdim_v);
    printf("nhead_q = %d\n", args.nhead_q);
    printf("nhead_k = %d\n", args.nhead_k);
    printf("scale_s = %f\n", args.scale_s);
    printf("scale_p = %f\n", args.scale_p);
    printf("scale_o = %f\n", args.scale_o);
    printf("stride_q = %d\n", args.stride_q);
    printf("stride_k = %d\n", args.stride_k);
    printf("stride_v = %d\n", args.stride_v);
    printf("stride_bias = %d\n", args.stride_bias);
    printf("stride_randval = %d\n", args.stride_randval);
    printf("stride_o = %d\n", args.stride_o);
    printf("nhead_stride_q = %d\n", args.nhead_stride_q);
    printf("nhead_stride_k = %d\n", args.nhead_stride_k);
    printf("nhead_stride_v = %d\n", args.nhead_stride_v);
    printf("nhead_stride_bias = %d\n", args.nhead_stride_bias);
    printf("nhead_stride_randval = %d\n", args.nhead_stride_randval);
    printf("nhead_stride_lse = %d\n", args.nhead_stride_lse);
    printf("nhead_stride_o = %d\n", args.nhead_stride_o);
    printf("batch_stride_q = %d\n", args.batch_stride_q);
    printf("batch_stride_k = %d\n", args.batch_stride_k);
    printf("batch_stride_v = %d\n", args.batch_stride_v);
    printf("batch_stride_bias = %d\n", args.batch_stride_bias);
    printf("batch_stride_randval = %d\n", args.batch_stride_randval);
    printf("batch_stride_lse = %d\n", args.batch_stride_lse);
    printf("batch_stride_o = %d\n", args.batch_stride_o);
    printf("window_size_left = %d\n", args.window_size_left);
    printf("window_size_right = %d\n", args.window_size_right);
    printf("mask_type = %d\n", args.mask_type);
    printf("p_drop = %f\n", args.p_drop);
    printf("s_randval = %d\n", args.s_randval);
}

} // namespace aiter