//go:build linux || windows

package gpu

/*
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -lpthread

#include "gpu_info.h"

*/
import "C"
import (
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"unsafe"
)

// handles holds the GPU management library handles obtained during detection.
// initGPUHandles fills in at most one of the two fields; both stay nil when no
// management library could be loaded.
type handles struct {
	cuda *C.cuda_handle_t // set when LoadCUDAMgmt returns a handle
	rocm *C.rocm_handle_t // set when LoadROCMMgmt returns a handle
}

var (
	// gpuMutex serializes GPU discovery and guards gpuHandles.
	gpuMutex sync.Mutex

	// gpuHandles caches the result of initGPUHandles. It is nil until the
	// first discovery has run.
	gpuHandles *handles
)

// CudaComputeMajorMin is the lowest CUDA compute capability major version
// that GetGPUInfo accepts before falling back to CPU mode.
// With our current CUDA compile flags, 5.2 and older will not work properly
const CudaComputeMajorMin = 6
// CudaLinuxGlobs lists the glob patterns searched, in order, for the
// nvidia-ml management library on Linux.
var CudaLinuxGlobs = []string{
	"/usr/local/cuda/lib64/libnvidia-ml.so*",
	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
	"/usr/lib/wsl/lib/libnvidia-ml.so*",
	"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
	"/opt/cuda/lib64/libnvidia-ml.so*",
	"/usr/lib*/libnvidia-ml.so*",
	"/usr/local/lib*/libnvidia-ml.so*",
	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",

	// TODO: are these stubs ever valid?
	"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
}

// CudaWindowsGlobs lists the glob patterns searched for the NVML management
// library on Windows.
var CudaWindowsGlobs = []string{`c:\Windows\System32\nvml.dll`}

// RocmLinuxGlobs lists the glob patterns searched for the ROCm SMI management
// library on Linux.
var RocmLinuxGlobs = []string{"/opt/rocm*/lib*/librocm_smi64.so*"}

// RocmWindowsGlobs lists the glob patterns searched for the ROCm SMI
// management library on Windows.
var RocmWindowsGlobs = []string{`c:\Windows\System32\rocm_smi64.dll`}

// Note: gpuMutex must already be held
func initGPUHandles() {
67

68
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
69

70
	gpuHandles = &handles{nil, nil}
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
	var cudaMgmtName string
	var cudaMgmtPatterns []string
	var rocmMgmtName string
	var rocmMgmtPatterns []string
	switch runtime.GOOS {
	case "windows":
		cudaMgmtName = "nvml.dll"
		cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
		copy(cudaMgmtPatterns, CudaWindowsGlobs)
		rocmMgmtName = "rocm_smi64.dll"
		rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
		copy(rocmMgmtPatterns, RocmWindowsGlobs)
	case "linux":
		cudaMgmtName = "libnvidia-ml.so"
		cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
		copy(cudaMgmtPatterns, CudaLinuxGlobs)
		rocmMgmtName = "librocm_smi64.so"
		rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
		copy(rocmMgmtPatterns, RocmLinuxGlobs)
	default:
		return
	}

94
	slog.Info("Detecting GPU type")
95
96
97
98
	cudaLibPaths := FindGPULibs(cudaMgmtName, cudaMgmtPatterns)
	if len(cudaLibPaths) > 0 {
		cuda := LoadCUDAMgmt(cudaLibPaths)
		if cuda != nil {
99
			slog.Info("Nvidia GPU detected")
100
101
102
103
			gpuHandles.cuda = cuda
			return
		}
	}
104

105
106
107
108
	rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
	if len(rocmLibPaths) > 0 {
		rocm := LoadROCMMgmt(rocmLibPaths)
		if rocm != nil {
109
			slog.Info("Radeon GPU detected")
110
111
			gpuHandles.rocm = rocm
			return
112
113
114
115
116
117
118
119
120
121
122
123
124
125
		}
	}
}

func GetGPUInfo() GpuInfo {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	if gpuHandles == nil {
		initGPUHandles()
	}

	var memInfo C.mem_info_t
126
	resp := GpuInfo{}
127
128
	if gpuHandles.cuda != nil {
		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
129
		if memInfo.err != nil {
130
			slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
131
132
			C.free(unsafe.Pointer(memInfo.err))
		} else {
133
134
135
136
			// Verify minimum compute capability
			var cc C.cuda_compute_capability_t
			C.cuda_compute_capability(*gpuHandles.cuda, &cc)
			if cc.err != nil {
137
				slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
138
139
				C.free(unsafe.Pointer(cc.err))
			} else if cc.major >= CudaComputeMajorMin {
140
				slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
141
142
				resp.Library = "cuda"
			} else {
143
				slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
144
			}
145
		}
146
147
	} else if gpuHandles.rocm != nil {
		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
148
		if memInfo.err != nil {
149
			slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
150
			C.free(unsafe.Pointer(memInfo.err))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
151
152
153
		} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
			// Only one GPU detected and it appears to be an integrated GPU - skip it
			slog.Info("ROCm unsupported integrated GPU detected")
154
		} else {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
			if memInfo.igpu_index >= 0 {
				// We have multiple GPUs reported, and one of them is an integrated GPU
				// so we have to set the env var to bypass it
				// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
				val := os.Getenv("ROCR_VISIBLE_DEVICES")
				if val == "" {
					devices := []string{}
					for i := 0; i < int(memInfo.count); i++ {
						if i == int(memInfo.igpu_index) {
							continue
						}
						devices = append(devices, strconv.Itoa(i))
					}
					val = strings.Join(devices, ",")
					os.Setenv("ROCR_VISIBLE_DEVICES", val)
				}
				slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
			}
173
			resp.Library = "rocm"
174
175
176
177
178
179
			var version C.rocm_version_resp_t
			C.rocm_get_version(*gpuHandles.rocm, &version)
			verString := C.GoString(version.str)
			if version.status == 0 {
				resp.Variant = "v" + verString
			} else {
180
				slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString))
181
182
			}
			C.free(unsafe.Pointer(version.str))
183
184
		}
	}
185
	if resp.Library == "" {
186
		C.cpu_check_ram(&memInfo)
187
188
		resp.Library = "cpu"
		resp.Variant = GetCPUVariant()
189
190
	}
	if memInfo.err != nil {
191
		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
192
		C.free(unsafe.Pointer(memInfo.err))
193
		return resp
194
	}
195
196

	resp.DeviceCount = uint32(memInfo.count)
197
198
199
200
201
	resp.FreeMemory = uint64(memInfo.free)
	resp.TotalMemory = uint64(memInfo.total)
	return resp
}

// getCPUMem returns the free and total memory reported by cpu_check_ram.
// On failure it returns a zero memInfo and an error carrying the message
// produced by the C side.
func getCPUMem() (memInfo, error) {
	var ret memInfo
	var info C.mem_info_t
	C.cpu_check_ram(&info)
	if info.err != nil {
		// Copy the message into Go memory before releasing the C string.
		msg := C.GoString(info.err)
		C.free(unsafe.Pointer(info.err))
		// errors.New rather than fmt.Errorf: the message comes from C and must
		// not be interpreted as a format string.
		return ret, errors.New(msg)
	}
	ret.FreeMemory = uint64(info.free)
	ret.TotalMemory = uint64(info.total)
	return ret, nil
}

func CheckVRAM() (int64, error) {
	gpuInfo := GetGPUInfo()
217
	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
218
		// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
219
220
		overhead := gpuInfo.FreeMemory / 10
		gpus := uint64(gpuInfo.DeviceCount)
221
222
		if overhead < gpus*1024*1024*1024 {
			overhead = gpus * 1024 * 1024 * 1024
223
		}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
224
225
226
		avail := int64(gpuInfo.FreeMemory - overhead)
		slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
		return avail, nil
227
228
	}

229
	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
230
}

// FindGPULibs returns the distinct library files matching the given glob
// patterns, followed by any file named baseLibName* found in the directories
// listed in PATH (windows) or LD_LIBRARY_PATH (linux). Symlinks are followed
// so the same library is not returned twice under different names. The result
// keeps discovery order and is empty (never nil) when nothing matches or the
// OS is unsupported.
//
// The patterns slice is not modified. Empty entries in the search path
// variable are ignored, so an unset variable does not cause the current
// working directory to be searched.
func FindGPULibs(baseLibName string, patterns []string) []string {
	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
	gpuLibPaths := []string{}
	slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))

	var ldPaths []string
	switch runtime.GOOS {
	case "windows":
		ldPaths = strings.Split(os.Getenv("PATH"), ";")
	case "linux":
		ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
	default:
		return gpuLibPaths
	}

	// Search the caller's patterns first, then whatever we find in the
	// PATH/LD_LIBRARY_PATH. Build the list in a fresh slice so the caller's
	// backing array is never appended to.
	searchPatterns := make([]string, 0, len(patterns)+len(ldPaths))
	searchPatterns = append(searchPatterns, patterns...)
	for _, ldPath := range ldPaths {
		if ldPath == "" {
			// filepath.Abs("") is the working directory; splitting an unset
			// variable yields exactly one such entry.
			continue
		}
		d, err := filepath.Abs(ldPath)
		if err != nil {
			continue
		}
		searchPatterns = append(searchPatterns, filepath.Join(d, baseLibName+"*"))
	}
	slog.Debug(fmt.Sprintf("gpu management search paths: %v", searchPatterns))

	seen := make(map[string]struct{})
	for _, pattern := range searchPatterns {
		// Ignore glob discovery errors
		matches, _ := filepath.Glob(pattern)
		for _, match := range matches {
			// Resolve any links so we don't try the same lib multiple times
			// and weed out any dups across globs
			libPath := resolveLibSymlinks(match)
			if _, dup := seen[libPath]; dup {
				continue
			}
			seen[libPath] = struct{}{}
			gpuLibPaths = append(gpuLibPaths, libPath)
		}
	}
	slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
	return gpuLibPaths
}

// resolveLibSymlinks follows the chain of symlinks starting at path and
// returns the final target. Relative link targets are resolved against the
// directory of the link itself. The walk is bounded so a symlink cycle cannot
// hang discovery; in that case the last path visited is returned.
func resolveLibSymlinks(path string) string {
	const maxHops = 40
	for hop := 0; hop < maxHops; hop++ {
		target, err := os.Readlink(path)
		if err != nil {
			// Not a symlink (or unreadable): path is the final answer.
			break
		}
		if !filepath.IsAbs(target) {
			target = filepath.Join(filepath.Dir(path), target)
		}
		path = target
	}
	return path
}

// LoadCUDAMgmt tries each candidate library path in order by calling
// cuda_init and returns the handle from the first one that initializes
// without error. Failures are logged and the next candidate is tried. It
// returns nil when no candidate loads.
func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
	var resp C.cuda_init_resp_t
	resp.ch.verbose = getVerboseState()
	for _, libPath := range cudaLibPaths {
		lib := C.CString(libPath)
		C.cuda_init(lib, &resp)
		// Release the path right away instead of deferring inside the loop,
		// which would keep every candidate's copy alive until return.
		// NOTE(review): assumes cuda_init does not retain the path pointer;
		// the previous deferred free released it at function return anyway.
		C.free(unsafe.Pointer(lib))
		if resp.err != nil {
			slog.Info(fmt.Sprintf("Unable to load CUDA management library %s: %s", libPath, C.GoString(resp.err)))
			C.free(unsafe.Pointer(resp.err))
			continue
		}
		return &resp.ch
	}
	return nil
}

// LoadROCMMgmt tries each candidate library path in order by calling
// rocm_init and returns the handle from the first one that initializes
// without error. Failures are logged and the next candidate is tried. It
// returns nil when no candidate loads.
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
	var resp C.rocm_init_resp_t
	resp.rh.verbose = getVerboseState()
	for _, libPath := range rocmLibPaths {
		lib := C.CString(libPath)
		C.rocm_init(lib, &resp)
		// Release the path right away instead of deferring inside the loop,
		// which would keep every candidate's copy alive until return.
		// NOTE(review): assumes rocm_init does not retain the path pointer;
		// the previous deferred free released it at function return anyway.
		C.free(unsafe.Pointer(lib))
		if resp.err != nil {
			slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err)))
			C.free(unsafe.Pointer(resp.err))
			continue
		}
		return &resp.rh
	}
	return nil
}

// getVerboseState returns 1 when the OLLAMA_DEBUG environment variable is set
// to any non-empty value and 0 otherwise. The result is stored in the verbose
// field of the CUDA/ROCm handles before they are initialized.
func getVerboseState() C.uint16_t {
	var verbose C.uint16_t
	if os.Getenv("OLLAMA_DEBUG") != "" {
		verbose = 1
	}
	return verbose
}