//go:build linux || windows

package gpu

/*
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -lpthread

#include "gpu_info.h"

*/
import "C"
import (
	"fmt"
	"log"
	"runtime"
	"sync"
	"unsafe"

	"github.com/jmorganca/ollama/api"
)

// handles holds the GPU management library handles discovered by
// initGPUHandles. A nil field means that library was not detected.
type handles struct {
	// cuda is set when cuda_init completes without reporting an error.
	cuda *C.cuda_handle_t
	// rocm is set when CUDA was not detected and rocm_init completes without
	// reporting an error.
	rocm *C.rocm_handle_t
}

var (
	// gpuMutex guards gpuHandles and is held for the duration of GetGPUInfo.
	gpuMutex sync.Mutex

	// gpuHandles stays nil until the first GetGPUInfo call populates it via
	// initGPUHandles.
	gpuHandles *handles
)

// Note: gpuMutex must already be held
func initGPUHandles() {
33
	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
	log.Printf("Detecting GPU type")
	gpuHandles = &handles{nil, nil}
	var resp C.cuda_init_resp_t
	C.cuda_init(&resp)
	if resp.err != nil {
		log.Printf("CUDA not detected: %s", C.GoString(resp.err))
		C.free(unsafe.Pointer(resp.err))

		var resp C.rocm_init_resp_t
		C.rocm_init(&resp)
		if resp.err != nil {
			log.Printf("ROCm not detected: %s", C.GoString(resp.err))
			C.free(unsafe.Pointer(resp.err))
		} else {
			log.Printf("Radeon GPU detected")
			rocm := resp.rh
			gpuHandles.rocm = &rocm
		}
	} else {
		log.Printf("Nvidia GPU detected")
		cuda := resp.ch
		gpuHandles.cuda = &cuda
	}
}

// GetGPUInfo reports which compute library is usable on this host along with
// its free and total memory.
//
// The GPU handles are initialized lazily on the first call. If a CUDA or ROCm
// handle is present, its VRAM is queried; when that query fails (or no GPU
// handle exists) the function falls back to system RAM and reports the
// library as "cpu" on Windows or "default" elsewhere. If the CPU memory lookup
// itself fails, the returned GpuInfo carries the library name but FreeMemory
// and TotalMemory are left unset.
//
// The whole call runs with gpuMutex held.
func GetGPUInfo() GpuInfo {
	// TODO - consider exploring lspci (and equivalent on windows) to check for
	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
	gpuMutex.Lock()
	defer gpuMutex.Unlock()
	if gpuHandles == nil {
		initGPUHandles()
	}

	var memInfo C.mem_info_t
	resp := GpuInfo{}
	if gpuHandles.cuda != nil {
		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
			// Clear the freed pointer so the CPU fallback below can never
			// observe (or free a second time) a dangling error string.
			memInfo.err = nil
		} else {
			resp.Library = "cuda"
		}
	} else if gpuHandles.rocm != nil {
		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
		if memInfo.err != nil {
			log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
			// Same as above: do not leave a dangling pointer behind.
			memInfo.err = nil
		} else {
			resp.Library = "rocm"
		}
	}

	if resp.Library == "" {
		// No usable GPU: fall back to reporting system memory.
		C.cpu_check_ram(&memInfo)
		// In the future we may offer multiple CPU variants to tune CPU features
		if runtime.GOOS == "windows" {
			resp.Library = "cpu"
		} else {
			resp.Library = "default"
		}
		if memInfo.err != nil {
			log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
			C.free(unsafe.Pointer(memInfo.err))
			return resp
		}
	}

	resp.FreeMemory = uint64(memInfo.free)
	resp.TotalMemory = uint64(memInfo.total)
	return resp
}

106
107
108
109
110
111
112
113
114
115
116
117
118
func getCPUMem() (memInfo, error) {
	var ret memInfo
	var info C.mem_info_t
	C.cpu_check_ram(&info)
	if info.err != nil {
		defer C.free(unsafe.Pointer(info.err))
		return ret, fmt.Errorf(C.GoString(info.err))
	}
	ret.FreeMemory = uint64(info.free)
	ret.TotalMemory = uint64(info.total)
	return ret, nil
}

119
120
func CheckVRAM() (int64, error) {
	gpuInfo := GetGPUInfo()
121
	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
122
123
124
125
126
127
128
129
130
131
		return int64(gpuInfo.FreeMemory), nil
	}
	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}

// NumGPU estimates how many model layers should be off-loaded to the GPU.
//
// An explicit opts.NumGPU (anything other than -1) is returned unchanged.
// Otherwise the result is 0 when only CPU memory is available, and else 75%
// of the number of layers that would fit in the free memory reported by
// GetGPUInfo, using fileSizeBytes/numLayer as the per-layer size estimate.
// It also returns 0 when that per-layer size cannot be computed (numLayer is
// not positive, or the file is smaller than one byte per layer).
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	info := GetGPUInfo()
	if info.Library == "cpu" || info.Library == "default" {
		return 0
	}

	// Guard the divisions below: a zero divisor would panic at run time.
	if numLayer <= 0 {
		log.Printf("unable to estimate GPU layers: invalid layer count %d", numLayer)
		return 0
	}

	/*
		Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
		We can store the model weights and the kv cache in vram,
		to enable kv cache vram storage add two additional layers to the number of layers retrieved from the model file.
	*/
	bytesPerLayer := uint64(fileSizeBytes / numLayer)
	if bytesPerLayer == 0 {
		log.Printf("unable to estimate GPU layers: model size %d bytes is too small for %d layers", fileSizeBytes, numLayer)
		return 0
	}

	// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
	layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4

	log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)

	return layers
}