dispatch_cutlass.h 448 Bytes
Newer Older
muyangli's avatar
muyangli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#pragma once

#include "common.h"
#include "Tensor.h"

#include <cassert>
#include <stdexcept>

#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/bfloat16.h>

/// @brief Maps a runtime 16-bit floating-point scalar type to the matching
///        CUTLASS element type and invokes @p func with it as an explicit
///        template argument.
///
/// @tparam F  Callable exposing a templated, zero-argument call operator,
///            i.e. something invocable as `func.template operator()<T>()`
///            (for example a lambda written as `[&]<typename T>() { ... }`).
/// @param type  Runtime scalar type selecting the instantiation:
///              `Tensor::FP16` -> `cutlass::half_t`,
///              `Tensor::BF16` -> `cutlass::bfloat16_t`.
/// @param func  The callable to invoke; it is called exactly once on success.
/// @throws std::invalid_argument if @p type is neither `Tensor::FP16` nor
///         `Tensor::BF16` (builds without NDEBUG hit the assert first).
template<typename F>
inline void dispatchF16(Tensor::ScalarType type, F &&func) {
    if (type == Tensor::FP16) {
        func.template operator()<cutlass::half_t>();
    } else if (type == Tensor::BF16) {
        func.template operator()<cutlass::bfloat16_t>();
    } else {
        // Debug builds keep the original behaviour and stop at the assert.
        assert(false && "dispatchF16: unsupported scalar type");
        // With NDEBUG the assert vanishes; without this throw the function
        // would return without ever calling func, and the caller would go on
        // as though the dispatched work had been performed.
        throw std::invalid_argument("dispatchF16: unsupported scalar type (expected FP16 or BF16)");
    }
}