gpu.go 10.5 KB
Newer Older
1
2
3
4
5
//go:build linux || windows

package gpu

/*
6
7
8
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -lpthread

9
10
11
12
13
14
#include "gpu_info.h"

*/
import "C"
import (
	"fmt"
15
	"log/slog"
16
17
	"os"
	"path/filepath"
18
	"runtime"
Daniel Hiltgen's avatar
Daniel Hiltgen committed
19
	"strconv"
20
	"strings"
21
22
23
24
25
26
27
28
29
30
31
32
	"sync"
	"unsafe"
)

// handles holds the GPU management library handles discovered by
// initGPUHandles. Detection returns as soon as one library loads, trying
// CUDA before ROCm, so at most one field is non-nil.
type handles struct {
	cuda *C.cuda_handle_t
	rocm *C.rocm_handle_t
}

var (
	// gpuMutex guards gpuHandles; GetGPUInfo holds it for the duration of
	// each call, including the lazy initialization.
	gpuMutex sync.Mutex

	// gpuHandles caches the loaded management library handles. It remains
	// nil until the first GetGPUInfo call runs initGPUHandles.
	gpuHandles *handles
)

33
34
// CudaComputeMin is the minimum supported CUDA compute capability, stored as
// {major, minor}. GetGPUInfo falls back to CPU mode for GPUs reporting less.
// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}
35

36
37
38
39
40
41
// CudaLinuxGlobs lists glob patterns for the possible locations of the
// nvidia-ml management library on Linux. initGPUHandles passes a copy of this
// list to FindGPULibs, which evaluates the patterns in order.
var CudaLinuxGlobs = []string{
	"/usr/local/cuda/lib64/libnvidia-ml.so*",
	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
	"/usr/lib/wsl/lib/libnvidia-ml.so*",
	"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
	"/opt/cuda/lib64/libnvidia-ml.so*",
	"/usr/lib*/libnvidia-ml.so*",
	"/usr/local/lib*/libnvidia-ml.so*",
	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",

	// TODO: are these stubs ever valid?
	"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
}

// Search patterns for the remaining platform/vendor combinations.
// initGPUHandles copies the list matching runtime.GOOS and hands it to
// FindGPULibs.
var (
	// CudaWindowsGlobs lists where to look for the NVML library on Windows.
	CudaWindowsGlobs = []string{`c:\Windows\System32\nvml.dll`}

	// RocmLinuxGlobs lists where to look for the ROCm SMI library on Linux.
	RocmLinuxGlobs = []string{"/opt/rocm*/lib*/librocm_smi64.so*"}

	// RocmWindowsGlobs lists where to look for the ROCm SMI library on Windows.
	RocmWindowsGlobs = []string{`c:\Windows\System32\rocm_smi64.dll`}
)

65
66
// Note: gpuMutex must already be held
func initGPUHandles() {
67

68
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
69

70
	gpuHandles = &handles{nil, nil}
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
	var cudaMgmtName string
	var cudaMgmtPatterns []string
	var rocmMgmtName string
	var rocmMgmtPatterns []string
	switch runtime.GOOS {
	case "windows":
		cudaMgmtName = "nvml.dll"
		cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
		copy(cudaMgmtPatterns, CudaWindowsGlobs)
		rocmMgmtName = "rocm_smi64.dll"
		rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
		copy(rocmMgmtPatterns, RocmWindowsGlobs)
	case "linux":
		cudaMgmtName = "libnvidia-ml.so"
		cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
		copy(cudaMgmtPatterns, CudaLinuxGlobs)
		rocmMgmtName = "librocm_smi64.so"
		rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
		copy(rocmMgmtPatterns, RocmLinuxGlobs)
	default:
		return
	}

94
	slog.Info("Detecting GPU type")
95
96
97
98
	cudaLibPaths := FindGPULibs(cudaMgmtName, cudaMgmtPatterns)
	if len(cudaLibPaths) > 0 {
		cuda := LoadCUDAMgmt(cudaLibPaths)
		if cuda != nil {
99
			slog.Info("Nvidia GPU detected")
100
101
102
103
			gpuHandles.cuda = cuda
			return
		}
	}
104

105
106
107
108
	rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
	if len(rocmLibPaths) > 0 {
		rocm := LoadROCMMgmt(rocmLibPaths)
		if rocm != nil {
109
			slog.Info("Radeon GPU detected")
110
111
			gpuHandles.rocm = rocm
			return
112
113
114
115
116
117
118
119
120
121
122
123
124
		}
	}
}

func GetGPUInfo() GpuInfo {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	if gpuHandles == nil {
		initGPUHandles()
	}

125
	// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
126
	cpuVariant := GetCPUVariant()
127
	if cpuVariant == "" && runtime.GOARCH == "amd64" {
128
129
130
		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
	}

131
	var memInfo C.mem_info_t
132
	resp := GpuInfo{}
133
	if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
134
		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
135
		if memInfo.err != nil {
136
			slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
137
			C.free(unsafe.Pointer(memInfo.err))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
138
		} else if memInfo.count > 0 {
139
140
141
142
			// Verify minimum compute capability
			var cc C.cuda_compute_capability_t
			C.cuda_compute_capability(*gpuHandles.cuda, &cc)
			if cc.err != nil {
143
				slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
144
				C.free(unsafe.Pointer(cc.err))
145
			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
146
				slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
147
148
				resp.Library = "cuda"
			} else {
149
				slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
150
			}
151
		}
152
153
154
155
	} else if AMDDetected() && gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
		ver, err := AMDDriverVersion()
		if err == nil {
			slog.Info("AMD Driver: " + ver)
156
157
158
		} else {
			// For now this is benign, but we may eventually need to fail compatibility checks
			slog.Debug("error looking up amd driver version: %s", err)
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
		}
		gfx := AMDGFXVersions()
		tooOld := false
		for _, v := range gfx {
			if v.Major < 9 {
				slog.Info("AMD GPU too old, falling back to CPU " + v.ToGFXString())
				tooOld = true
				break
			}

			// TODO - remap gfx strings for unsupporetd minor/patch versions to supported for the same major
			// e.g. gfx1034 works if we map it to gfx1030 at runtime

		}
		if !tooOld {
			// TODO - this algo can be shifted over to use sysfs instead of the rocm info library...
			C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
			if memInfo.err != nil {
				slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
				C.free(unsafe.Pointer(memInfo.err))
			} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
				// Only one GPU detected and it appears to be an integrated GPU - skip it
				slog.Info("ROCm unsupported integrated GPU detected")
			} else if memInfo.count > 0 {
				if memInfo.igpu_index >= 0 {
					// We have multiple GPUs reported, and one of them is an integrated GPU
					// so we have to set the env var to bypass it
					// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
					val := os.Getenv("ROCR_VISIBLE_DEVICES")
					if val == "" {
						devices := []string{}
						for i := 0; i < int(memInfo.count); i++ {
							if i == int(memInfo.igpu_index) {
								continue
							}
							devices = append(devices, strconv.Itoa(i))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
195
						}
196
197
						val = strings.Join(devices, ",")
						os.Setenv("ROCR_VISIBLE_DEVICES", val)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
198
					}
199
					slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
200
				}
201
202
203
204
205
206
207
208
209
210
				resp.Library = "rocm"
				var version C.rocm_version_resp_t
				C.rocm_get_version(*gpuHandles.rocm, &version)
				verString := C.GoString(version.str)
				if version.status == 0 {
					resp.Variant = "v" + verString
				} else {
					slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString))
				}
				C.free(unsafe.Pointer(version.str))
211
			}
212
213
		}
	}
214
	if resp.Library == "" {
215
		C.cpu_check_ram(&memInfo)
216
		resp.Library = "cpu"
217
		resp.Variant = cpuVariant
218
219
	}
	if memInfo.err != nil {
220
		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
221
		C.free(unsafe.Pointer(memInfo.err))
222
		return resp
223
	}
224
225

	resp.DeviceCount = uint32(memInfo.count)
226
227
228
229
230
	resp.FreeMemory = uint64(memInfo.free)
	resp.TotalMemory = uint64(memInfo.total)
	return resp
}

231
232
233
234
235
236
237
238
239
240
241
242
243
// getCPUMem returns the free and total system memory reported by
// cpu_check_ram. When the lookup reports an error, the C error string is
// copied into the returned error and released, and the memInfo is zero.
func getCPUMem() (memInfo, error) {
	var ret memInfo
	var info C.mem_info_t
	C.cpu_check_ram(&info)
	if info.err != nil {
		msg := C.GoString(info.err)
		C.free(unsafe.Pointer(info.err))
		// Pass the message as an argument, not as the format string, so any
		// '%' it contains is reported verbatim.
		return ret, fmt.Errorf("%s", msg)
	}
	ret.FreeMemory = uint64(info.free)
	ret.TotalMemory = uint64(info.total)
	return ret, nil
}

244
245
func CheckVRAM() (int64, error) {
	gpuInfo := GetGPUInfo()
246
	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
247
		// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
248
249
		overhead := gpuInfo.FreeMemory / 10
		gpus := uint64(gpuInfo.DeviceCount)
250
251
		if overhead < gpus*1024*1024*1024 {
			overhead = gpus * 1024 * 1024 * 1024
252
		}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
253
254
255
		avail := int64(gpuInfo.FreeMemory - overhead)
		slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
		return avail, nil
256
257
	}

258
	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
259
}
260
261
262
263
264

// maxGPULibSymlinkHops bounds how many symbolic links resolveGPULibPath
// follows for a single candidate, so a link cycle cannot hang discovery.
const maxGPULibSymlinkHops = 40

// FindGPULibs returns the de-duplicated list of candidate paths for the GPU
// management library baseLibName.
//
// Candidates come from the supplied glob patterns, evaluated in order,
// followed by "<dir>/<baseLibName>*" for every non-empty directory in PATH
// (Windows) or LD_LIBRARY_PATH (Linux). Symbolic links are resolved so the
// same library is only reported once. Candidates whose links cannot be
// resolved within maxGPULibSymlinkHops hops are dropped. The caller's
// patterns slice is never modified. On other operating systems the result is
// empty.
func FindGPULibs(baseLibName string, patterns []string) []string {
	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
	var ldPaths []string
	gpuLibPaths := []string{}
	slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))

	switch runtime.GOOS {
	case "windows":
		ldPaths = strings.Split(os.Getenv("PATH"), ";")
	case "linux":
		ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
	default:
		return gpuLibPaths
	}

	// Cap the capacity so the appends below always allocate a new backing
	// array instead of writing into spare capacity owned by the caller.
	patterns = patterns[:len(patterns):len(patterns)]

	// Start with whatever we find in the PATH/LD_LIBRARY_PATH
	for _, ldPath := range ldPaths {
		// An unset variable splits into a single empty entry, and
		// filepath.Abs("") is the working directory; skip empty entries so
		// the working directory is not searched by accident.
		if ldPath == "" {
			continue
		}
		d, err := filepath.Abs(ldPath)
		if err != nil {
			continue
		}
		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
	}
	slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))

	for _, pattern := range patterns {
		// Ignore glob discovery errors
		matches, _ := filepath.Glob(pattern)
		for _, match := range matches {
			// Resolve any links so we don't try the same lib multiple times
			// and weed out any dups across globs
			libPath, ok := resolveGPULibPath(match)
			if !ok {
				slog.Debug(fmt.Sprintf("skipping %s: too many levels of symbolic links", match))
				continue
			}
			seen := false
			for _, known := range gpuLibPaths {
				if known == libPath {
					seen = true
					break
				}
			}
			if !seen {
				gpuLibPaths = append(gpuLibPaths, libPath)
			}
		}
	}
	slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
	return gpuLibPaths
}

// resolveGPULibPath follows symbolic links starting at path and returns the
// final target. Relative link targets are interpreted against the directory
// of the link that contains them. It reports false when the chain is longer
// than maxGPULibSymlinkHops, which is what a link cycle looks like.
func resolveGPULibPath(path string) (string, bool) {
	for hops := 0; hops < maxGPULibSymlinkHops; hops++ {
		target, err := os.Readlink(path)
		if err != nil {
			// Not a link (or not readable as one): this is the final path.
			return path, true
		}
		if !filepath.IsAbs(target) {
			target = filepath.Join(filepath.Dir(path), target)
		}
		path = target
	}
	return "", false
}

// LoadCUDAMgmt tries each candidate library path in order and returns the
// handle from the first one that cuda_init loads without error. Failures are
// logged and the next candidate is tried; nil is returned when none load.
func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
	var resp C.cuda_init_resp_t
	resp.ch.verbose = getVerboseState()
	for _, libPath := range cudaLibPaths {
		lib := C.CString(libPath)
		C.cuda_init(lib, &resp)
		// Free the C copy of the path per iteration; a defer here would keep
		// every candidate's string allocated until the function returns.
		C.free(unsafe.Pointer(lib))
		if resp.err == nil {
			return &resp.ch
		}
		slog.Info(fmt.Sprintf("Unable to load CUDA management library %s: %s", libPath, C.GoString(resp.err)))
		C.free(unsafe.Pointer(resp.err))
		// resp is reused for the next candidate; don't leave a freed pointer in it.
		resp.err = nil
	}
	return nil
}

// LoadROCMMgmt tries each candidate library path in order and returns the
// handle from the first one that rocm_init loads without error. Failures are
// logged and the next candidate is tried; nil is returned when none load.
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
	var resp C.rocm_init_resp_t
	resp.rh.verbose = getVerboseState()
	for _, libPath := range rocmLibPaths {
		lib := C.CString(libPath)
		C.rocm_init(lib, &resp)
		// Free the C copy of the path per iteration; a defer here would keep
		// every candidate's string allocated until the function returns.
		C.free(unsafe.Pointer(lib))
		if resp.err == nil {
			return &resp.rh
		}
		slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err)))
		C.free(unsafe.Pointer(resp.err))
		// resp is reused for the next candidate; don't leave a freed pointer in it.
		resp.err = nil
	}
	return nil
}
348
349
350
351
352
353
354

// getVerboseState returns 1 when the OLLAMA_DEBUG environment variable is set
// to any non-empty value and 0 otherwise. The result is stored in the
// verbose field of the handles passed to the C init functions.
func getVerboseState() C.uint16_t {
	var verbose C.uint16_t
	if os.Getenv("OLLAMA_DEBUG") != "" {
		verbose = 1
	}
	return verbose
}