amd_hip_windows.go 4.12 KB
Newer Older
Daniel Hiltgen's avatar
Daniel Hiltgen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
package gpu

import (
	"fmt"
	"log/slog"
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
)

// HIP status codes. They are compared against the first return value of
// syscall.SyscallN when calling into the HIP library.
const (
	hipSuccess       = 0   // the call succeeded
	hipErrorNoDevice = 100 // HipGetDeviceCount treats this as "no devices found"
)

// hipDevicePropMinimal is the buffer whose address is passed to
// hipGetDeviceProperties by HipGetDeviceProperties.
//
// NOTE(review): the field sizes and ordering presumably mirror the leading
// portion of the C device-properties struct expected by the loaded HIP
// library, with unused1/unused2 acting as opaque padding — confirm against
// the HIP headers before changing any field, since the library writes into
// this memory directly.
type hipDevicePropMinimal struct {
	Name        [256]byte
	unused1     [140]byte
	GcnArchName [256]byte // gfx####
	iGPU        int       // Doesn't seem to actually report correctly
	unused2     [128]byte
}

// HipLib wraps the amdhip64_6.dll library for GPU discovery.
//
// It holds the module handle returned by LoadLibrary together with the
// addresses of the HIP exports that the methods on this type invoke through
// syscall.SyscallN.
type HipLib struct {
	dll                    windows.Handle // loaded module handle; methods treat 0 as "unloaded"
	hipGetDeviceCount      uintptr        // address of the hipGetDeviceCount export
	hipGetDeviceProperties uintptr        // address of the hipGetDeviceProperties export
	hipMemGetInfo          uintptr        // address of the hipMemGetInfo export
	hipSetDevice           uintptr        // address of the hipSetDevice export
	hipDriverGetVersion    uintptr        // address of the hipDriverGetVersion export
}

func NewHipLib() (*HipLib, error) {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
36
37
	// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
	h, err := windows.LoadLibrary("amdhip64_6.dll")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
38
	if err != nil {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
39
		return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
	}
	hl := &HipLib{}
	hl.dll = h
	hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
	if err != nil {
		return nil, err
	}
	hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
	if err != nil {
		return nil, err
	}
	hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
	if err != nil {
		return nil, err
	}
	hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
	if err != nil {
		return nil, err
	}
	hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
	if err != nil {
		return nil, err
	}
	return hl, nil
}

// The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
// so we have to unload/reset the library after we do our initial discovery
// to make sure our updates to that variable are processed by llama.cpp
func (hl *HipLib) Release() {
	err := windows.FreeLibrary(hl.dll)
	if err != nil {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
72
		slog.Warn("failed to unload amdhip64.dll", "error", err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
73
74
75
76
	}
	hl.dll = 0
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
77
func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
78
	if hl.dll == 0 {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
79
		return 0, 0, fmt.Errorf("dll has been unloaded")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
80
81
82
83
	}
	var version int
	status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
	if status != hipSuccess {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
84
		return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
85
	}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
86
87

	slog.Debug("hipDriverGetVersion", "version", version)
88
89
	driverMajor = version / 10000000
	driverMinor = (version - (driverMajor * 10000000)) / 100000
Daniel Hiltgen's avatar
Daniel Hiltgen committed
90
91

	return driverMajor, driverMinor, nil
Daniel Hiltgen's avatar
Daniel Hiltgen committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
}

func (hl *HipLib) HipGetDeviceCount() int {
	if hl.dll == 0 {
		slog.Error("dll has been unloaded")
		return 0
	}
	var count int
	status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
	if status == hipErrorNoDevice {
		slog.Info("AMD ROCm reports no devices found")
		return 0
	}
	if status != hipSuccess {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
106
		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
	}
	return count
}

// HipSetDevice calls hipSetDevice with the given device index.
// It returns an error if the library has been released or if the HIP call
// reports a status other than hipSuccess.
func (hl *HipLib) HipSetDevice(device int) error {
	if hl.dll == 0 {
		return fmt.Errorf("dll has been unloaded")
	}
	rc, _, errno := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
	if rc == hipSuccess {
		return nil
	}
	return fmt.Errorf("failed call to hipSetDevice: %d %s", rc, errno)
}

// HipGetDeviceProperties calls hipGetDeviceProperties for the given device
// index and returns the buffer the library filled in.
// It returns an error if the library has been released or if the HIP call
// reports a status other than hipSuccess.
func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
	if hl.dll == 0 {
		return nil, fmt.Errorf("dll has been unloaded")
	}
	// The library writes directly into this buffer through the raw pointer.
	props := new(hipDevicePropMinimal)
	rc, _, errno := syscall.SyscallN(
		hl.hipGetDeviceProperties,
		uintptr(unsafe.Pointer(props)),
		uintptr(device),
	)
	if rc == hipSuccess {
		return props, nil
	}
	return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", rc, errno)
}

// HipMemGetInfo calls hipMemGetInfo and returns, in order, the free memory,
// the total memory, and an error.
// The error is non-nil if the library has been released or if the HIP call
// reports a status other than hipSuccess; both sizes are 0 in that case.
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
	if hl.dll == 0 {
		return 0, 0, fmt.Errorf("dll has been unloaded")
	}
	var free, total uint64
	rc, _, errno := syscall.SyscallN(
		hl.hipMemGetInfo,
		uintptr(unsafe.Pointer(&free)),
		uintptr(unsafe.Pointer(&total)),
	)
	if rc == hipSuccess {
		return free, total, nil
	}
	return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", rc, errno)
}