amd_hip_windows.go 4.06 KB
Newer Older
Daniel Hiltgen's avatar
Daniel Hiltgen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package gpu

import (
	"fmt"
	"log/slog"
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
)

// hipSuccess is the HIP status code returned by a call that completed without error.
const hipSuccess = 0

// hipErrorNoDevice is the HIP status code that HipGetDeviceCount treats as
// "no devices found" rather than as a failure.
const hipErrorNoDevice = 100

// hipDevicePropMinimal mirrors just enough of the C hipDeviceProp_t layout
// for hipGetDeviceProperties to write into it. The runtime fills the whole
// C struct through the pointer it is given, so the Go type's total size must
// match the C definition (792 bytes); the unused* fields are padding that
// keeps the named fields at their C offsets:
//
//	Name        @ 0   (char[256])
//	GcnArchName @ 396 (char[256])
//	iGPU        @ 652 (C int "integrated")
//
// NOTE(review): offsets/size derived from the HIP 5.x hip_runtime_api.h
// definition; re-verify if amdhip64.dll is ever swapped for a runtime with a
// different hipDeviceProp_t ABI.
type hipDevicePropMinimal struct {
	Name        [256]byte
	unused1     [140]byte
	GcnArchName [256]byte // gfx####

	// iGPU maps the C "integrated" field, a 32-bit int. Declaring it as a Go
	// int (8 bytes, 8-byte aligned on 64-bit) pushed it to offset 656 so it
	// read the following C fields instead, which is why it never reported
	// correctly.
	iGPU    int32
	unused2 [136]byte
}

// HipLib wraps the amdhip64.dll library for GPU discovery. It keeps the module
// handle obtained from LoadLibrary together with the resolved addresses of the
// HIP entry points, which the methods invoke through syscall.SyscallN.
// A zero dll means the library has been released.
type HipLib struct {
	// Module handle for amdhip64.dll; Release sets it back to zero.
	dll windows.Handle

	// HIP entry points resolved with GetProcAddress by NewHipLib.
	hipGetDeviceCount      uintptr
	hipGetDeviceProperties uintptr
	hipMemGetInfo          uintptr
	hipSetDevice           uintptr
	hipDriverGetVersion    uintptr
}

// NewHipLib loads amdhip64.dll and resolves every HIP entry point HipLib uses.
//
// On success the caller owns the returned library and should call Release when
// finished with it. If any symbol cannot be resolved, the DLL is unloaded again
// before returning so the module handle is not leaked; the returned error names
// the missing symbol and wraps the underlying lookup error.
func NewHipLib() (*HipLib, error) {
	h, err := windows.LoadLibrary("amdhip64.dll")
	if err != nil {
		return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
	}
	hl := &HipLib{dll: h}

	// Table of symbol name -> field to populate, resolved in the same order as before.
	procs := []struct {
		name string
		addr *uintptr
	}{
		{"hipGetDeviceCount", &hl.hipGetDeviceCount},
		{"hipGetDeviceProperties", &hl.hipGetDeviceProperties},
		{"hipMemGetInfo", &hl.hipMemGetInfo},
		{"hipSetDevice", &hl.hipSetDevice},
		{"hipDriverGetVersion", &hl.hipDriverGetVersion},
	}
	for _, p := range procs {
		addr, err := windows.GetProcAddress(h, p.name)
		if err != nil {
			// Best-effort cleanup: the lookup failure is the error worth
			// reporting, so a FreeLibrary error here is deliberately dropped.
			_ = windows.FreeLibrary(h)
			return nil, fmt.Errorf("unable to find %s in amdhip64.dll: %w", p.name, err)
		}
		*p.addr = addr
	}
	return hl, nil
}

// The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
// so we have to unload/reset the library after we do our initial discovery
// to make sure our updates to that variable are processed by llama.cpp
func (hl *HipLib) Release() {
	err := windows.FreeLibrary(hl.dll)
	if err != nil {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
71
		slog.Warn("failed to unload amdhip64.dll", "error", err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
72
73
74
75
	}
	hl.dll = 0
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
76
func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
77
	if hl.dll == 0 {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
78
		return 0, 0, fmt.Errorf("dll has been unloaded")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
79
80
81
82
	}
	var version int
	status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
	if status != hipSuccess {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
83
		return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
84
	}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
85
86
87
88
89
90
91

	slog.Debug("hipDriverGetVersion", "version", version)
	// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
	driverMajor = version / 1000
	driverMinor = (version - (driverMajor * 1000)) / 10

	return driverMajor, driverMinor, nil
Daniel Hiltgen's avatar
Daniel Hiltgen committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
}

func (hl *HipLib) HipGetDeviceCount() int {
	if hl.dll == 0 {
		slog.Error("dll has been unloaded")
		return 0
	}
	var count int
	status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
	if status == hipErrorNoDevice {
		slog.Info("AMD ROCm reports no devices found")
		return 0
	}
	if status != hipSuccess {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
106
		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
	}
	return count
}

// HipSetDevice calls hipSetDevice with the given device index. It returns an
// error if the DLL has been released or the call reports a non-success status
// (the message includes the status and the OS error value from the syscall).
func (hl *HipLib) HipSetDevice(device int) error {
	if hl.dll == 0 {
		return fmt.Errorf("dll has been unloaded")
	}
	rc, _, callErr := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
	if rc == hipSuccess {
		return nil
	}
	return fmt.Errorf("failed call to hipSetDevice: %d %s", rc, callErr)
}

// HipGetDeviceProperties asks hipGetDeviceProperties to fill a freshly
// allocated hipDevicePropMinimal for the given device index and returns it.
// It returns a nil struct and an error if the DLL has been released or the
// call reports a non-success status.
func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
	if hl.dll == 0 {
		return nil, fmt.Errorf("dll has been unloaded")
	}
	props := new(hipDevicePropMinimal)
	// The pointer conversion stays inside the SyscallN call expression so the
	// struct is kept alive for the duration of the call.
	rc, _, callErr := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(props)), uintptr(device))
	if rc == hipSuccess {
		return props, nil
	}
	return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", rc, callErr)
}

// HipMemGetInfo returns (free, total, err) as reported by hipMemGetInfo.
// NOTE(review): presumably for the device selected via HipSetDevice — confirm
// against the HIP documentation for the runtime in use.
// Both values are zero when an error is returned.
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
	if hl.dll == 0 {
		return 0, 0, fmt.Errorf("dll has been unloaded")
	}
	var free, total uint64
	rc, _, callErr := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&free)), uintptr(unsafe.Pointer(&total)))
	if rc == hipSuccess {
		return free, total, nil
	}
	return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", rc, callErr)
}