amd_linux.go 13.5 KB
Newer Older
Daniel Hiltgen's avatar
Daniel Hiltgen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package gpu

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
)

// Discovery logic for AMD/ROCm GPUs

const (
	// DriverVersionFile is the sysfs file the amdgpu driver version is read from.
	DriverVersionFile = "/sys/module/amdgpu/version"

	// AMDNodesSysfsDir is the sysfs directory holding one numbered sub-directory per topology node.
	AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/"

	// GPUPropertiesFileGlob matches the properties file of every topology node.
	GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties"

	// The two memory globs below are relative: prefix with the node dir
	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"

	// RocmStandardLocation is the well known location of a host installed ROCm library directory.
	RocmStandardLocation = "/opt/rocm/lib"

	// IGPUMemLimit is the total memory (in bytes) below which a device is treated as an iGPU.
	// 512M is what they typically report, so anything less than 1G must be iGPU
	// TODO find a better way to detect iGPU instead of minimum memory
	IGPUMemLimit = 1024 * 1024 * 1024
)

// ROCmLibGlobs holds the file globs used to validate if the given ROCm lib is usable
// TODO - probably include more coverage of files here...
var ROCmLibGlobs = []string{
	"libhipblas.so.2*",
	"rocblas",
}

// Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
// and the user hasn't already set this variable
func AMDGetGPUInfo(resp *GpuInfo) {
	// TODO - DRY this out with windows
	if !AMDDetected() {
		return
	}
	skip := map[int]interface{}{}

	// Opportunistic logging of driver version to aid in troubleshooting
	ver, err := AMDDriverVersion()
	if err == nil {
		slog.Info("AMD Driver: " + ver)
	} else {
		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
		slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
	}

	// If the user has specified exactly which GPUs to use, look up their memory
	visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
	if visibleDevices != "" {
		ids := []int{}
		for _, idStr := range strings.Split(visibleDevices, ",") {
			id, err := strconv.Atoi(idStr)
			if err != nil {
				slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
			} else {
				ids = append(ids, id)
			}
		}
		amdProcMemLookup(resp, nil, ids)
		return
	}

	// Gather GFX version information from all detected cards
	gfx := AMDGFXVersions()
	verStrings := []string{}
	for i, v := range gfx {
		verStrings = append(verStrings, v.ToGFXString())
		if v.Major == 0 {
			// Silently skip CPUs
			skip[i] = struct{}{}
			continue
		}
		if v.Major < 9 {
			// TODO consider this a build-time setting if we can support 8xx family GPUs
			slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
			skip[i] = struct{}{}
		}
	}
	slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))

	// Abort if all GPUs are skipped
	if len(skip) >= len(gfx) {
		slog.Info("all detected amdgpus are skipped, falling back to CPU")
		return
	}

	// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
	libDir, err := AMDValidateLibDir()
	if err != nil {
		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
		return
	}

103
104
	updateLibPath(libDir)

Daniel Hiltgen's avatar
Daniel Hiltgen committed
105
106
107
108
109
110
111
112
113
114
115
116
117
	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
	if gfxOverride == "" {
		supported, err := GetSupportedGFX(libDir)
		if err != nil {
			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
			return
		}
		slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))

		for i, v := range gfx {
			if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
				// TODO - consider discrete markdown just for ROCM troubleshooting?
Michael Yang's avatar
Michael Yang committed
118
				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
				skip[i] = struct{}{}
			} else {
				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
			}
		}
	} else {
		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
	}

	if len(skip) >= len(gfx) {
		slog.Info("all detected amdgpus are skipped, falling back to CPU")
		return
	}

	ids := make([]int, len(gfx))
	i := 0
	for k := range gfx {
		ids[i] = k
		i++
	}
	amdProcMemLookup(resp, skip, ids)
	if resp.memInfo.DeviceCount == 0 {
		return
	}
	if len(skip) > 0 {
		amdSetVisibleDevices(ids, skip)
	}
}

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// updateLibPath appends libDir to the LD_LIBRARY_PATH environment variable of
// this process unless it is already one of its entries.
//
// An unset or empty LD_LIBRARY_PATH is replaced by libDir alone: joining onto
// an empty value would produce ":libDir", and the dynamic linker treats an
// empty entry as the current working directory.
func updateLibPath(libDir string) {
	val := libDir
	if current := os.Getenv("LD_LIBRARY_PATH"); current != "" {
		if slices.Contains(strings.Split(current, ":"), libDir) {
			return
		}
		// Keep the user's existing value untouched and add ours at the end
		val = current + ":" + libDir
	}
	slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
	if err := os.Setenv("LD_LIBRARY_PATH", val); err != nil {
		slog.Warn(fmt.Sprintf("failed to update LD_LIBRARY_PATH: %s", err))
	}
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
163
164
165
166
167
168
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
	resp.memInfo.DeviceCount = 0
	resp.memInfo.TotalMemory = 0
	resp.memInfo.FreeMemory = 0
Daniel Hiltgen's avatar
Daniel Hiltgen committed
169
	slog.Debug("discovering VRAM for amdgpu devices")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
	if len(ids) == 0 {
		entries, err := os.ReadDir(AMDNodesSysfsDir)
		if err != nil {
			slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
			return
		}
		for _, node := range entries {
			if !node.IsDir() {
				continue
			}
			id, err := strconv.Atoi(node.Name())
			if err != nil {
				slog.Warn("malformed amdgpu sysfs node id " + node.Name())
				continue
			}
			ids = append(ids, id)
		}
	}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
188
	slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
189
190
191
192
193
194
195

	for _, id := range ids {
		if _, skipped := skip[id]; skipped {
			continue
		}
		totalMemory := uint64(0)
		usedMemory := uint64(0)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
196
197
		// Adjust for sysfs vs HIP ids
		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
		propFiles, err := filepath.Glob(propGlob)
		if err != nil {
			slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
		}
		// 1 or more memory banks - sum the values of all of them
		for _, propFile := range propFiles {
			fp, err := os.Open(propFile)
			if err != nil {
				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
				continue
			}
			defer fp.Close()
			scanner := bufio.NewScanner(fp)
			for scanner.Scan() {
				line := strings.TrimSpace(scanner.Text())
				if strings.HasPrefix(line, "size_in_bytes") {
					ver := strings.Fields(line)
					if len(ver) != 2 {
						slog.Warn("malformed " + line)
						continue
					}
					bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
					if err != nil {
						slog.Warn("malformed int " + line)
						continue
					}
					totalMemory += bankSizeInBytes
				}
			}
		}
		if totalMemory == 0 {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
229
230
231
232
233
234
235
			slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
			skip[id] = struct{}{}
			continue
		}
		if totalMemory < IGPUMemLimit {
			slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
			skip[id] = struct{}{}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
			continue
		}
		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
		usedFiles, err := filepath.Glob(usedGlob)
		if err != nil {
			slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
			continue
		}
		for _, usedFile := range usedFiles {
			fp, err := os.Open(usedFile)
			if err != nil {
				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
				continue
			}
			defer fp.Close()
			data, err := io.ReadAll(fp)
			if err != nil {
				slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
				continue
			}
			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
			if err != nil {
				slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
				continue
			}
			usedMemory += used
		}
Daniel Hiltgen's avatar
Daniel Hiltgen committed
263
264
		slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
		slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory  %dM", id, (totalMemory-usedMemory)/1024/1024))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
		resp.memInfo.DeviceCount++
		resp.memInfo.TotalMemory += totalMemory
		resp.memInfo.FreeMemory += (totalMemory - usedMemory)
	}
	if resp.memInfo.DeviceCount > 0 {
		resp.Library = "rocm"
	}
}

// AMDDetected reports whether the amdgpu driver appears to be present, so that
// amdgpu discovery can be skipped when it is not.
func AMDDetected() bool {
	// Some driver versions (older?) don't have a version file, so just lookup the parent dir
	sysfsDir := filepath.Dir(DriverVersionFile)
	switch _, statErr := os.Stat(sysfsDir); {
	case statErr == nil:
		return true
	case errors.Is(statErr, os.ErrNotExist):
		slog.Debug("amdgpu driver not detected " + sysfsDir)
	default:
		slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, statErr))
	}
	return false
}

// setupLink replaces whatever currently exists at target with a symlink
// pointing at source. Anything found at target is removed recursively first.
func setupLink(source, target string) error {
	rmErr := os.RemoveAll(target)
	if rmErr != nil {
		return fmt.Errorf("failed to remove old rocm directory %s %w", target, rmErr)
	}

	linkErr := os.Symlink(source, target)
	if linkErr == nil {
		slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
		return nil
	}
	return fmt.Errorf("failed to create link %s => %s %w", source, target, linkErr)
}

// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) {
	// We rely on the rpath compiled into our library to find rocm
	// so we establish a symlink to wherever we find it on the system
306
307
	// to <payloads>/rocm
	payloadsDir, err := PayloadsDir()
Daniel Hiltgen's avatar
Daniel Hiltgen committed
308
	if err != nil {
309
		return "", err
Daniel Hiltgen's avatar
Daniel Hiltgen committed
310
	}
311
312

	// If we already have a rocm dependency wired, nothing more to do
313
	rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
Daniel Hiltgen's avatar
Daniel Hiltgen committed
314
315
316
	if rocmLibUsable(rocmTargetDir) {
		return rocmTargetDir, nil
	}
317

318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
	// next to the running binary
	exe, err := os.Executable()
	if err == nil {
		peerDir := filepath.Dir(exe)
		if rocmLibUsable(peerDir) {
			slog.Debug("detected ROCM next to ollama executable " + peerDir)
			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
		}
		peerDir = filepath.Join(filepath.Dir(exe), "rocm")
		if rocmLibUsable(peerDir) {
			slog.Debug("detected ROCM next to ollama executable " + peerDir)
			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
		}
	}

333
334
335
336
	// Well known ollama installer path
	installedRocmDir := "/usr/share/ollama/lib/rocm"
	if rocmLibUsable(installedRocmDir) {
		return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
	}

	// Prefer explicit HIP env var
	hipPath := os.Getenv("HIP_PATH")
	if hipPath != "" {
		hipLibDir := filepath.Join(hipPath, "lib")
		if rocmLibUsable(hipLibDir) {
			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
			return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
		}
	}

	// Scan the library path for potential matches
	ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
	for _, ldPath := range ldPaths {
		d, err := filepath.Abs(ldPath)
		if err != nil {
			continue
		}
		if rocmLibUsable(d) {
			return rocmTargetDir, setupLink(d, rocmTargetDir)
		}
	}

	// Well known location(s)
	if rocmLibUsable("/opt/rocm/lib") {
		return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
	}

366
367
	// If we still haven't found a usable rocm, the user will have to install it on their own
	slog.Warn("amdgpu detected, but no compatible rocm library found.  Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install")
Daniel Hiltgen's avatar
Daniel Hiltgen committed
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

// AMDDriverVersion returns the whitespace-trimmed content of DriverVersionFile.
// An error is returned when the file is missing or cannot be opened or read.
func AMDDriverVersion() (string, error) {
	if _, statErr := os.Stat(DriverVersionFile); statErr != nil {
		return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, statErr)
	}

	versionFile, openErr := os.Open(DriverVersionFile)
	if openErr != nil {
		return "", openErr
	}
	defer versionFile.Close()

	raw, readErr := io.ReadAll(versionFile)
	if readErr != nil {
		return "", readErr
	}
	version := strings.TrimSpace(string(raw))
	return version, nil
}

func AMDGFXVersions() map[int]Version {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
389
390
	// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
Daniel Hiltgen's avatar
Daniel Hiltgen committed
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
	res := map[int]Version{}
	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
	for _, match := range matches {
		fp, err := os.Open(match)
		if err != nil {
			slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
			continue
		}
		defer fp.Close()
		i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
		if err != nil {
			slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
			continue
		}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
406
407
408
409
410
411
412
		if i == 0 {
			// Skipping the CPU
			continue
		}
		// Align with HIP IDs (zero is first GPU, not CPU)
		i -= 1

Daniel Hiltgen's avatar
Daniel Hiltgen committed
413
414
415
416
417
418
		scanner := bufio.NewScanner(fp)
		for scanner.Scan() {
			line := strings.TrimSpace(scanner.Text())
			if strings.HasPrefix(line, "gfx_target_version") {
				ver := strings.Fields(line)
				if len(ver) != 2 || len(ver[1]) < 5 {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
419
					if ver[1] != "0" {
Daniel Hiltgen's avatar
Daniel Hiltgen committed
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
						slog.Debug("malformed " + line)
					}
					res[i] = Version{
						Major: 0,
						Minor: 0,
						Patch: 0,
					}
					continue
				}
				l := len(ver[1])
				patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
				minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
				major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
				if err1 != nil || err2 != nil || err3 != nil {
					slog.Debug("malformed int " + line)
					continue
				}

				res[i] = Version{
					Major: uint(major),
					Minor: uint(minor),
					Patch: uint(patch),
				}
			}
		}
	}
	return res
}

// ToGFXString renders the version as a gfx target name, e.g. gfx1030 or gfx90a.
// The major version is decimal while minor and patch are single hexadecimal
// digits in AMD's target naming, so 9.0.10 is gfx90a and not gfx9010.
func (v Version) ToGFXString() string {
	return fmt.Sprintf("gfx%d%x%x", v.Major, v.Minor, v.Patch)
}