"src/llamafactory/data/template.py" did not exist on "0d4db43f32cb3c472da832ad8586517f670235e2"
package gpu

import (
	"errors"
	"fmt"
	"log/slog"
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"
)

// Status values compared against the first result of the HIP calls made in
// this file.
const (
	hipSuccess       = 0   // the call succeeded
	hipErrorNoDevice = 100 // treated by HipGetDeviceCount as "no devices found"
)

// hipDevicePropMinimal is the buffer handed to hipGetDeviceProperties. Only
// Name, GcnArchName and iGPU are named; the unused* arrays are padding that
// positions them.
//
// NOTE(review): the field sizes and order presumably mirror the native
// hipDeviceProp_t layout of the HIP version being loaded — confirm against
// the HIP headers before changing any field.
type hipDevicePropMinimal struct {
	Name        [256]byte
	unused1     [140]byte
	GcnArchName [256]byte // gfx####
	iGPU        int       // Doesn't seem to actually report correctly
	unused2     [128]byte
}

// HipLib wraps the amdhip64.dll library for GPU discovery. It holds the
// loaded module handle together with the addresses of the HIP entry points
// that the methods on this type call.
type HipLib struct {
	dll                    windows.Handle // module handle; zero once Release has run
	hipGetDeviceCount      uintptr
	hipGetDeviceProperties uintptr
	hipMemGetInfo          uintptr
	hipSetDevice           uintptr
	hipDriverGetVersion    uintptr
}

func NewHipLib() (*HipLib, error) {
xuxzh1's avatar
init  
xuxzh1 committed
37
38
	// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
	h, err := windows.LoadLibrary("amdhip64_6.dll")
mashun1's avatar
v1  
mashun1 committed
39
	if err != nil {
xuxzh1's avatar
init  
xuxzh1 committed
40
		return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
mashun1's avatar
v1  
mashun1 committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
	}
	hl := &HipLib{}
	hl.dll = h
	hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount")
	if err != nil {
		return nil, err
	}
	hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties")
	if err != nil {
		return nil, err
	}
	hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo")
	if err != nil {
		return nil, err
	}
	hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice")
	if err != nil {
		return nil, err
	}
	hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion")
	if err != nil {
		return nil, err
	}
	return hl, nil
}

// Release unloads the HIP DLL and marks the wrapper as unusable.
//
// The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup
// so we have to unload/reset the library after we do our initial discovery
// to make sure our updates to that variable are processed by llama.cpp
//
// Calling Release again after the library has been unloaded is a no-op. A
// failure to unload is logged as a warning and the handle is cleared anyway.
func (hl *HipLib) Release() {
	if hl.dll == 0 {
		// Already released: nothing to unload.
		return
	}
	if err := windows.FreeLibrary(hl.dll); err != nil {
		slog.Warn("failed to unload amdhip64.dll", "error", err)
	}
	hl.dll = 0
}

func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
	if hl.dll == 0 {
xuxzh1's avatar
init  
xuxzh1 committed
80
		return 0, 0, errors.New("dll has been unloaded")
mashun1's avatar
v1  
mashun1 committed
81
82
83
84
85
86
87
88
	}
	var version int
	status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version)))
	if status != hipSuccess {
		return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err)
	}

	slog.Debug("hipDriverGetVersion", "version", version)
xuxzh1's avatar
init  
xuxzh1 committed
89
90
	driverMajor = version / 10000000
	driverMinor = (version - (driverMajor * 10000000)) / 100000
mashun1's avatar
v1  
mashun1 committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

	return driverMajor, driverMinor, nil
}

// HipGetDeviceCount returns the number of devices reported by
// hipGetDeviceCount.
//
// It returns 0 when the DLL has been released, when HIP reports
// hipErrorNoDevice, or when the call fails with any other status; each case
// is logged rather than returned as an error.
func (hl *HipLib) HipGetDeviceCount() int {
	if hl.dll == 0 {
		slog.Error("dll has been unloaded")
		return 0
	}

	// The native out-parameter is a C int (32 bits on Windows), so receive it
	// in an int32 rather than Go's 64-bit int.
	var count int32
	status, _, errno := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count)))
	switch status {
	case hipSuccess:
		return int(count)
	case hipErrorNoDevice:
		slog.Info("AMD ROCm reports no devices found")
		return 0
	default:
		// count may not have been written on failure; do not trust it.
		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", errno)
		return 0
	}
}

func (hl *HipLib) HipSetDevice(device int) error {
	if hl.dll == 0 {
xuxzh1's avatar
init  
xuxzh1 committed
114
		return errors.New("dll has been unloaded")
mashun1's avatar
v1  
mashun1 committed
115
116
117
118
119
120
121
122
123
124
	}
	status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device))
	if status != hipSuccess {
		return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err)
	}
	return nil
}

func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) {
	if hl.dll == 0 {
xuxzh1's avatar
init  
xuxzh1 committed
125
		return nil, errors.New("dll has been unloaded")
mashun1's avatar
v1  
mashun1 committed
126
127
128
129
130
131
132
133
134
135
136
137
	}
	var props hipDevicePropMinimal
	status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device))
	if status != hipSuccess {
		return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err)
	}
	return &props, nil
}

// free, total, err
func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) {
	if hl.dll == 0 {
xuxzh1's avatar
init  
xuxzh1 committed
138
		return 0, 0, errors.New("dll has been unloaded")
mashun1's avatar
v1  
mashun1 committed
139
140
141
142
143
144
145
146
147
	}
	var totalMemory uint64
	var freeMemory uint64
	status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory)))
	if status != hipSuccess {
		return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err)
	}
	return freeMemory, totalMemory, nil
}