utils.hpp 1.46 KB
Newer Older
1
2
#pragma once

3
#include <torch/extension.h>
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include "cutlass/numeric_types.h"
#include "helper.h"

/// Type trait mapping an element type to the type used on the CUTLASS side.
///
/// The primary template is the identity mapping: any type without an explicit
/// specialization is passed through unchanged as `type`.
template <typename Element>
struct cutlass_dtype {
  using type = Element;
};

template <>
struct cutlass_dtype<half> {
  using type = cutlass::half_t;
};

template <>
struct cutlass_dtype<nv_bfloat16> {
  using type = cutlass::bfloat16_t;
};

template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
  using type = cutlass::float_e4m3_t;
};

template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
  using type = cutlass::float_e5m2_t;
};

/// Convenience alias for `typename cutlass_dtype<Element>::type`.
template <typename Element>
using cutlass_dtype_t = typename cutlass_dtype<Element>::type;

template<typename T>
struct DeviceAllocation {
  T* ptr_ = nullptr;
  size_t offset_ = 0;
  size_t size_ = 0;
40
  torch::Tensor tensor;
41
42
43
44
45
46

  DeviceAllocation(DeviceAllocation const&) = delete;
  DeviceAllocation& operator=(DeviceAllocation const&) = delete;

  DeviceAllocation() = default;
  DeviceAllocation(size_t size) { reset(size); }
47
  ~DeviceAllocation() {}
48
49

  void reset(size_t size, size_t offset=0) {
50
51
52
53
54
    size_t num_element = sizeof(T) * (size + offset);
    auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);

    tensor = torch::empty(num_element, options);
    ptr_ = tensor.data_ptr<T>();
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
    size_ = size;
    offset_ = offset;
  }

  T* get() {
    return ptr_ + offset_;
  }

  const T* get() const {
    return ptr_ + offset_;
  }

  size_t size() const { return size_; }

  size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
70
};