gpu.go 3.79 KB
Newer Older
1
2
3
4
5
//go:build linux || windows

package gpu

/*
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -lpthread

#include "gpu_info.h"

*/
import "C"
import (
	"fmt"
	"log"
16
	"runtime"
17
18
19
20
21
22
23
24
25
26
27
28
	"sync"
	"unsafe"
)

// handles holds the C-side GPU library handles filled in by initGPUHandles.
// initGPUHandles sets at most one of the two fields; both stay nil when
// neither library initializes.
type handles struct {
	cuda *C.cuda_handle_t // set when C.cuda_init reports no error
	rocm *C.rocm_handle_t // set when CUDA init fails and C.rocm_init reports no error
}

var (
	// gpuMutex must be held while reading or initializing gpuHandles.
	gpuMutex sync.Mutex

	// gpuHandles stays nil until initGPUHandles runs, which GetGPUInfo
	// triggers the first time it finds the pointer unset.
	gpuHandles *handles
)

29
30
// CudaComputeMajorMin is the lowest CUDA compute capability major version
// that GetGPUInfo accepts before falling back to CPU mode.
// With our current CUDA compile flags, 5.2 and older will not work properly.
const CudaComputeMajorMin = 6
31

32
33
// Note: gpuMutex must already be held
func initGPUHandles() {
34
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
	log.Printf("Detecting GPU type")
	gpuHandles = &handles{nil, nil}
	var resp C.cuda_init_resp_t
	C.cuda_init(&resp)
	if resp.err != nil {
		log.Printf("CUDA not detected: %s", C.GoString(resp.err))
		C.free(unsafe.Pointer(resp.err))

		var resp C.rocm_init_resp_t
		C.rocm_init(&resp)
		if resp.err != nil {
			log.Printf("ROCm not detected: %s", C.GoString(resp.err))
			C.free(unsafe.Pointer(resp.err))
		} else {
			log.Printf("Radeon GPU detected")
			rocm := resp.rh
			gpuHandles.rocm = &rocm
		}
	} else {
		log.Printf("Nvidia GPU detected")
		cuda := resp.ch
		gpuHandles.cuda = &cuda
	}
}

// GetGPUInfo reports which compute library to use together with the free and
// total memory available to it.
//
// The call holds gpuMutex throughout and runs initGPUHandles the first time
// gpuHandles is found unset. Selection rules:
//   - CUDA handle present: Library is "cuda" when the VRAM lookup and the
//     compute capability lookup both succeed and the major version is at
//     least CudaComputeMajorMin.
//   - ROCm handle present: Library is "rocm" when the VRAM lookup succeeds.
//   - Otherwise system RAM is queried and Library is "cpu" on Windows or
//     "default" on other platforms.
//
// If the system RAM lookup fails, the returned GpuInfo carries only Library;
// FreeMemory and TotalMemory are left at zero.
func GetGPUInfo() GpuInfo {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	if gpuHandles == nil {
		initGPUHandles()
	}

	var memInfo C.mem_info_t
	resp := GpuInfo{}
	if gpuHandles.cuda != nil {
		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
			// Drop the freed pointer: memInfo is reused for the CPU fallback
			// and checked again below, so a stale value could otherwise be
			// logged and freed a second time.
			memInfo.err = nil
		} else {
			// Verify minimum compute capability
			var cc C.cuda_compute_capability_t
			C.cuda_compute_capability(*gpuHandles.cuda, &cc)
			if cc.err != nil {
				log.Printf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))
				C.free(unsafe.Pointer(cc.err))
			} else if cc.major >= CudaComputeMajorMin {
				log.Printf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)
				resp.Library = "cuda"
			} else {
				log.Printf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)
			}
		}
	} else if gpuHandles.rocm != nil {
		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
			// Same reasoning as the CUDA branch: never leave a freed pointer
			// behind for the shared error check below.
			memInfo.err = nil
		} else {
			resp.Library = "rocm"
		}
	}
	if resp.Library == "" {
		// No usable GPU library: report system RAM instead of VRAM.
		C.cpu_check_ram(&memInfo)
		// In the future we may offer multiple CPU variants to tune CPU features
		if runtime.GOOS == "windows" {
			resp.Library = "cpu"
		} else {
			resp.Library = "default"
		}
	}
	// A GPU branch that set Library left memInfo.err nil, and GPU failures
	// clear it above, so a non-nil error here can only come from cpu_check_ram.
	if memInfo.err != nil {
		log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
		C.free(unsafe.Pointer(memInfo.err))
		return resp
	}
	resp.FreeMemory = uint64(memInfo.free)
	resp.TotalMemory = uint64(memInfo.total)
	return resp
}

118
119
120
121
122
123
124
125
126
127
128
129
130
// getCPUMem queries system memory through the C helper cpu_check_ram.
//
// On success it returns a memInfo with FreeMemory and TotalMemory filled in.
// When the helper reports an error, it returns the zero memInfo and an error
// carrying the helper's message; the C string is freed before returning.
func getCPUMem() (memInfo, error) {
	var ret memInfo
	var info C.mem_info_t
	C.cpu_check_ram(&info)
	if info.err != nil {
		// Copy the message into Go memory first so the C allocation can be
		// released immediately.
		msg := C.GoString(info.err)
		C.free(unsafe.Pointer(info.err))
		// The message is external text: pass it as an argument, never as the
		// format string, so a '%' in it cannot be misread as a verb.
		return ret, fmt.Errorf("%s", msg)
	}
	ret.FreeMemory = uint64(info.free)
	ret.TotalMemory = uint64(info.total)
	return ret, nil
}

131
132
// CheckVRAM returns the number of bytes of GPU memory considered usable, based
// on GetGPUInfo.
//
// A reserve of 10% of free memory, or 384 MiB if that is larger, is held back
// for overhead. When free memory does not exceed the reserve the result is 0
// rather than a wrapped-around value. An error is returned when GetGPUInfo
// reports no free memory or a library other than "cuda" or "rocm".
func CheckVRAM() (int64, error) {
	// Smallest amount of VRAM always held back, in bytes.
	const minOverhead = 384 * 1024 * 1024

	gpuInfo := GetGPUInfo()
	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
		// leave 10% or 384Mi of VRAM free for unaccounted for overhead
		overhead := gpuInfo.FreeMemory / 10
		if overhead < minOverhead {
			overhead = minOverhead
		}
		// FreeMemory is unsigned: subtracting a larger reserve would wrap and
		// surface as a negative size after the int64 conversion.
		if gpuInfo.FreeMemory <= overhead {
			return 0, nil
		}
		return int64(gpuInfo.FreeMemory - overhead), nil
	}

	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determination
}