common.hpp 488 Bytes
Newer Older
1
2
3
#pragma once

#include "cutlass/cutlass.h"
4
#include <climits>
5
6
7
8
9
10
11
12
13

/**
 * Helper macro for checking CUTLASS errors.
 *
 * Evaluates `status` exactly once, then hands the result to TORCH_CHECK:
 * the condition is `status == cutlass::Status::kSuccess` and the message
 * argument is `cutlassGetStatusString(status)`.
 *
 * Expands to a brace-enclosed block, so it may be used as a statement with
 * or without a trailing semicolon.
 */
#define CUTLASS_CHECK(status)                                            \
  {                                                                      \
    /* Capture once: `status` may be an expression with side effects, */ \
    /* and it is needed both for the test and for the message.        */ \
    auto const cutlass_check_status_ = (status);                         \
    TORCH_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,      \
                cutlassGetStatusString(cutlass_check_status_));          \
  }
14
15
16
17
18
19

/**
 * Round `num` up to the nearest power of two.
 *
 * @param num Value to round up.
 * @return `num` itself when it is 0, 1, or already a power of two; otherwise
 *         the smallest power of two greater than `num`. When that power of
 *         two does not fit in 32 bits (i.e. `num` > 2^31) the result wraps
 *         to 0.
 */
inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;

  constexpr uint32_t kBits = CHAR_BIT * sizeof(uint32_t);
  // num >= 2 here, so (num - 1) is non-zero as __builtin_clz requires.
  uint32_t const shift =
      kBits - static_cast<uint32_t>(__builtin_clz(num - 1));

  // A shift by the full width is undefined; 2^32 is not representable.
  if (shift >= kBits) return 0;

  // Shift an unsigned one so bit 31 can be set without touching a sign bit.
  return uint32_t{1} << shift;
}