device.cpp 3.31 KB
Newer Older
1
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <new>
#include <stdexcept>
#include <string>
Chao Liu's avatar
Chao Liu committed
2
3
4
5
6
7
8
9
10
#include "device.hpp"

/// Allocates `mem_size` bytes of device memory with hipMalloc.
/// @param mem_size Number of bytes to allocate.
/// @throws std::runtime_error if hipMalloc reports an error.
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
    const hipError_t status = hipMalloc(&mpDeviceBuf, mMemSize);
    if(status != hipSuccess)
    {
        // Fail loudly: continuing would leave mpDeviceBuf holding an unusable pointer.
        throw std::runtime_error(std::string("hipMalloc failed: ") + hipGetErrorString(status));
    }
}

void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }

Chao Liu's avatar
Chao Liu committed
11
12
std::size_t DeviceMem::GetBufferSize() { return mMemSize; }

Chao Liu's avatar
Chao Liu committed
13
14
15
16
17
18
19
20
21
22
23
void DeviceMem::ToDevice(const void* p)
{
    hipGetErrorString(
        hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p)
{
    hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

Chao Liu's avatar
Chao Liu committed
24
25
void DeviceMem::SetZero() { hipGetErrorString(hipMemset(mpDeviceBuf, 0, mMemSize)); }

26
DeviceMem::~DeviceMem()
{
    // Release the device allocation. A destructor must not throw, so the status
    // returned by hipFree is only handed to hipGetErrorString and not acted on.
    const auto free_status = hipFree(mpDeviceBuf);
    hipGetErrorString(free_status);
}
Chao Liu's avatar
Chao Liu committed
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/// Allocates `mem_size` bytes of host memory whose start address is a multiple
/// of `alignment`.
///
/// The block is over-allocated with malloc; the pointer malloc returned is
/// stored in the pointer-sized slot immediately below the aligned address so
/// the destructor can free it.
///
/// @param mem_size  Usable size in bytes.
/// @param alignment Required alignment; must be a non-zero power of two.
/// @throws std::bad_alloc if the size computation overflows or malloc fails.
DeviceAlignedMemCPU::DeviceAlignedMemCPU(std::size_t mem_size, std::size_t alignment)
    : mMemSize(mem_size), mAlignment(alignment)
{
    assert(!(alignment == 0 || (alignment & (alignment - 1)))); // check pow of 2

    // Room for worst-case alignment padding plus the hidden back-pointer slot.
    const std::size_t offset = alignment - 1 + sizeof(void*);

    // Guard against wrap-around in mem_size + offset.
    if(mem_size > std::numeric_limits<std::size_t>::max() - offset)
    {
        throw std::bad_alloc();
    }

    void* const raw = std::malloc(mem_size + offset);
    // Checked at runtime (not only via assert) so release builds never
    // write through a null-derived pointer below.
    if(raw == nullptr)
    {
        throw std::bad_alloc();
    }

    const std::uintptr_t aligned_addr =
        (reinterpret_cast<std::uintptr_t>(raw) + offset) &
        ~static_cast<std::uintptr_t>(alignment - 1);

    void** const aligned = reinterpret_cast<void**>(aligned_addr);
    aligned[-1]          = raw; // read back by the destructor to free the block
    mpDeviceBuf          = aligned;
}

void* DeviceAlignedMemCPU::GetDeviceBuffer() { return mpDeviceBuf; }

std::size_t DeviceAlignedMemCPU::GetBufferSize() { return mMemSize; }

void DeviceAlignedMemCPU::SetZero() { memset(mpDeviceBuf, 0, mMemSize); }

DeviceAlignedMemCPU::~DeviceAlignedMemCPU()
{
    // The constructor stashed the original malloc pointer in the slot just
    // below the aligned buffer; that is the pointer that must be freed.
    void** const aligned = static_cast<void**>(mpDeviceBuf);
    free(aligned[-1]);
}

Chao Liu's avatar
Chao Liu committed
52
53
54
55
struct KernelTimerImpl
{
    KernelTimerImpl()
    {
Chao Liu's avatar
Chao Liu committed
56
57
        hipGetErrorString(hipEventCreate(&mStart));
        hipGetErrorString(hipEventCreate(&mEnd));
Chao Liu's avatar
Chao Liu committed
58
59
60
61
    }

    ~KernelTimerImpl()
    {
Chao Liu's avatar
Chao Liu committed
62
63
        hipGetErrorString(hipEventDestroy(mStart));
        hipGetErrorString(hipEventDestroy(mEnd));
Chao Liu's avatar
Chao Liu committed
64
65
66
67
    }

    void Start()
    {
Chao Liu's avatar
Chao Liu committed
68
69
        hipGetErrorString(hipDeviceSynchronize());
        hipGetErrorString(hipEventRecord(mStart, nullptr));
Chao Liu's avatar
Chao Liu committed
70
71
72
73
    }

    void End()
    {
Chao Liu's avatar
Chao Liu committed
74
75
        hipGetErrorString(hipEventRecord(mEnd, nullptr));
        hipGetErrorString(hipEventSynchronize(mEnd));
Chao Liu's avatar
Chao Liu committed
76
77
78
79
80
    }

    float GetElapsedTime() const
    {
        float time;
Chao Liu's avatar
Chao Liu committed
81
        hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
Chao Liu's avatar
Chao Liu committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        return time;
    }

    hipEvent_t mStart, mEnd;
};

KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {}

KernelTimer::~KernelTimer() {}

void KernelTimer::Start() { impl->Start(); }

void KernelTimer::End() { impl->End(); }

float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); }
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

/// Implementation behind WallTimer: measures host wall-clock time between
/// Start() and End().
struct WallTimerImpl
{
    // steady_clock is guaranteed monotonic. high_resolution_clock may be an
    // alias of system_clock, which can jump when the system time is adjusted.
    using Clock = std::chrono::steady_clock;

    void Start() { mStart = Clock::now(); }

    void End() { mStop = Clock::now(); }

    /// @return Milliseconds between the last Start() and End() calls. The
    ///         duration is converted directly to fractional milliseconds
    ///         rather than first being truncated to whole microseconds.
    float GetElapsedTime() const
    {
        return std::chrono::duration<float, std::milli>(mStop - mStart).count();
    }

    Clock::time_point mStart;
    Clock::time_point mStop;
};

WallTimer::WallTimer() : impl(new WallTimerImpl()) {}

WallTimer::~WallTimer() {}

void WallTimer::Start() { impl->Start(); }

void WallTimer::End() { impl->End(); }

float WallTimer::GetElapsedTime() const { return impl->GetElapsedTime(); }