device_memory.cpp 1.29 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
5
6
#include "ck/host_utility/hip_check_error.hpp"

#include "ck/library/utility/device_memory.hpp"
Chao Liu's avatar
Chao Liu committed
7
8
9
10
11
12

DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
    hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}

carlushuang's avatar
carlushuang committed
13
14
15
16
17
18
19
20
21
22
void DeviceMem::Realloc(std::size_t mem_size)
{
    if(mpDeviceBuf)
    {
        hip_check_error(hipFree(mpDeviceBuf));
    }
    mMemSize = mem_size;
    hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}

23
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
Chao Liu's avatar
Chao Liu committed
24

25
std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
Chao Liu's avatar
Chao Liu committed
26

27
void DeviceMem::ToDevice(const void* p) const
Chao Liu's avatar
Chao Liu committed
28
{
carlushuang's avatar
carlushuang committed
29
30
31
32
33
    if(mpDeviceBuf)
    {
        hip_check_error(
            hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
    }
Chao Liu's avatar
Chao Liu committed
34
35
}

36
void DeviceMem::FromDevice(void* p) const
Chao Liu's avatar
Chao Liu committed
37
{
carlushuang's avatar
carlushuang committed
38
39
40
41
    if(mpDeviceBuf)
    {
        hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
    }
Chao Liu's avatar
Chao Liu committed
42
43
}

carlushuang's avatar
carlushuang committed
44
45
46
47
48
49
50
void DeviceMem::SetZero() const
{
    if(mpDeviceBuf)
    {
        hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize));
    }
}
Chao Liu's avatar
Chao Liu committed
51

carlushuang's avatar
carlushuang committed
52
53
54
55
56
57
58
DeviceMem::~DeviceMem()
{
    if(mpDeviceBuf)
    {
        hip_check_error(hipFree(mpDeviceBuf));
    }
}