hydcu.go 7.9 KB
Newer Older
songlinfeng's avatar
songlinfeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
package hydcu

import (
	"bufio"
	"fmt"
	"io/ioutil"
	"math"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
)

// DeviceInfo describes one DCU device (a PCI device or a platform partition
// device) as reported by GetHyDCUsWithFS.
type DeviceInfo struct {
	// DrmDevices lists the device's DRM nodes as /dev/dri paths (card* and renderD*).
	DrmDevices    []string
	// PartitionType is "<compute>_<memory>", built from the lower-cased contents of
	// the device's current_compute_partition and current_memory_partition files,
	// or "" when neither file yields a value.
	PartitionType string
}

// FileSystem abstracts the filesystem and device queries used for DCU
// discovery, so the *WithFS functions can run against an implementation other
// than the host filesystem.
type FileSystem interface {
	// Stat returns file info for name (DefaultFS delegates to os.Stat).
	Stat(name string) (os.FileInfo, error)
	// Glob returns the paths matching pattern (DefaultFS delegates to filepath.Glob).
	Glob(pattern string) ([]string, error)
	// ReadFile returns the contents of name (DefaultFS delegates to os.ReadFile).
	ReadFile(name string) ([]byte, error)
	// GetDeviceStat returns `stat -c format dev` output (DefaultFS trims whitespace).
	GetDeviceStat(dev string, format string) (string, error)
}

// DefaultFS implements FileSystem on top of the host filesystem.
type DefaultFS struct{}

// defaultFS is the FileSystem used by the entry points that take no FileSystem.
var defaultFS FileSystem = &DefaultFS{}

// Stat delegates to os.Stat.
func (d *DefaultFS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) }

// Glob delegates to filepath.Glob.
func (d *DefaultFS) Glob(pattern string) ([]string, error) { return filepath.Glob(pattern) }

// ReadFile delegates to os.ReadFile.
func (d *DefaultFS) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) }

// GetDeviceStat runs `stat -c <format> <dev>` and returns its stdout with
// surrounding whitespace trimmed. On failure it prints a diagnostic to stdout
// and returns "" together with the command's error.
func (d *DefaultFS) GetDeviceStat(dev string, format string) (string, error) {
	out, err := exec.Command("stat", "-c", format, dev).Output()
	if err != nil {
		// Printf, not Println: the message carries format verbs. Println would
		// print them literally and fails the printf vet check run by `go test`.
		fmt.Printf("stat failed for %v: Error %v\n", dev, err)
		return "", err
	}
	return strings.TrimSpace(string(out)), nil
}

// GetHYDCUs returns every DCU device on the system, discovered through the
// host filesystem (defaultFS). Grouping and ordering follow GetHyDCUsWithFS.
func GetHYDCUs() ([]DeviceInfo, error) {
	devices, err := GetHyDCUsWithFS(defaultFS)
	return devices, err
}

// GetHyDCUsWithFS discovers DCU devices through fs.
//
// It walks the PCI devices bound to the hydcu driver plus any
// /sys/devices/platform/hydcu_xcp_* partition devices. For each one it
// collects the DRM card*/renderD* nodes as /dev/dri paths. It records the
// partition type as "<compute>_<memory>" (lower-cased, or "" when neither
// partition file yields a value).
//
// Devices are grouped by the topology unique_id of their render node. Groups
// appear in the order their unique_id was first seen, and entries within a
// group are sorted by render minor. Devices without a render node, or whose
// render minor is not in the KFD topology, are omitted.
//
// It returns the fs.Stat error when the hydcu driver directory is missing, and
// any error from globbing the PCI driver or per-device drm directories.
func GetHyDCUsWithFS(fs FileSystem) ([]DeviceInfo, error) {
	if _, err := fs.Stat("/sys/module/hydcu/drivers/"); err != nil {
		return nil, err
	}

	renderDevIds := GetDevIdsFromTopology(fs)

	// Devices grouped by unique_id; uniqueIds keeps first-seen order so the
	// output is deterministic despite map iteration order.
	uniqueIdDevices := make(map[string][]DeviceInfo)
	var uniqueIds []string

	pciDevs, err := fs.Glob("/sys/module/hydcu/drivers/pci:hydcu/[0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F]:*")
	if err != nil {
		// Printf, not Println: the message carries a format verb.
		fmt.Printf("Failed to find hydcu driver directories: %v\n", err)
		return nil, err
	}

	// Platform devices represent partitions; a glob error here is ignored so
	// discovery still reports the PCI devices.
	platformDevs, _ := fs.Glob("/sys/devices/platform/hydcu_xcp_*")

	// Combine both PCI and platform devices and process them identically.
	allDevs := make([]string, 0, len(pciDevs)+len(platformDevs))
	allDevs = append(allDevs, pciDevs...)
	allDevs = append(allDevs, platformDevs...)

	for _, path := range allDevs {
		partitionType := readHydcuPartitionType(path)

		drms, err := fs.Glob(filepath.Join(path, "drm", "*"))
		if err != nil {
			return nil, err
		}

		drmDevs, renderMinor := collectHydcuDrmNodes(drms)
		if len(drmDevs) == 0 || renderMinor <= 0 {
			continue
		}
		devID, ok := renderDevIds[renderMinor]
		if !ok {
			continue
		}
		if _, seen := uniqueIdDevices[devID]; !seen {
			uniqueIds = append(uniqueIds, devID)
		}
		uniqueIdDevices[devID] = append(uniqueIdDevices[devID], DeviceInfo{DrmDevices: drmDevs, PartitionType: partitionType})
	}

	// Flatten the groups in first-seen order, sorting each by render minor.
	var devs []DeviceInfo
	for _, devID := range uniqueIds {
		group := uniqueIdDevices[devID]
		sort.Slice(group, func(i, j int) bool {
			return hydcuRenderID(group[i]) < hydcuRenderID(group[j])
		})
		devs = append(devs, group...)
	}

	return devs, nil
}

// readHydcuPartitionType returns "<compute>_<memory>". The parts come from the
// trimmed, lower-cased current_compute_partition and current_memory_partition
// files under devPath. An unreadable file contributes an empty part, and the
// result is "" when both parts are empty.
//
// NOTE(review): these reads go to the host filesystem, not the injected
// FileSystem, so a fake FileSystem cannot control them. Switching to
// fs.ReadFile would also require dropping the io/ioutil import.
func readHydcuPartitionType(devPath string) string {
	read := func(name string) string {
		data, err := ioutil.ReadFile(filepath.Join(devPath, name))
		if err != nil {
			return ""
		}
		return strings.ToLower(strings.TrimSpace(string(data)))
	}
	compute := read("current_compute_partition")
	memory := read("current_memory_partition")
	if compute == "" && memory == "" {
		return ""
	}
	return compute + "_" + memory
}

// collectHydcuDrmNodes maps card* and renderD* entries of drms to /dev/dri
// paths. It also returns the minor number parsed from the last renderD entry
// (0 if there is none or its suffix is not a number).
func collectHydcuDrmNodes(drms []string) ([]string, int) {
	var drmDevs []string
	renderMinor := 0
	for _, drm := range drms {
		name := filepath.Base(drm)
		switch {
		case strings.HasPrefix(name, "renderD"):
			drmDevs = append(drmDevs, "/dev/dri/"+name)
			// A malformed suffix yields 0, which the caller treats as "no render node".
			renderMinor, _ = strconv.Atoi(strings.TrimPrefix(name, "renderD"))
		case strings.HasPrefix(name, "card"):
			drmDevs = append(drmDevs, "/dev/dri/"+name)
		}
	}
	return drmDevs, renderMinor
}

// hydcuRenderID returns the minor number of the first renderD node in d,
// or 0 when d has none (or its suffix is not a number). Used as a sort key.
func hydcuRenderID(d DeviceInfo) int {
	for _, dev := range d.DrmDevices {
		if name := filepath.Base(dev); strings.HasPrefix(name, "renderD") {
			id, _ := strconv.Atoi(strings.TrimPrefix(name, "renderD"))
			return id
		}
	}
	return 0
}

// Patterns for "<name> <value>" lines of a KFD topology node properties file.
var (
	// topoUniqueIdRe captures the decimal value of a node's unique_id line.
	topoUniqueIdRe = regexp.MustCompile(`unique_id\s(\d+)`)
	// renderMinorRe captures the decimal value of a node's drm_render_minor line.
	renderMinorRe = regexp.MustCompile(`drm_render_minor\s(\d+)`)
)

// GetDevIdsFromTopology builds a map from DRM render minor number to the
// unique_id string of the matching KFD topology node.
//
// The topology root defaults to /sys/class/kfd/kfd. It is replaced only when
// exactly one topoRootParam is given. A node is skipped when its properties
// file lacks a drm_render_minor, the minor is not in (0, MaxInt32], or it
// lacks a unique_id. A glob failure yields an empty map.
func GetDevIdsFromTopology(fs FileSystem, topoRootParam ...string) map[int]string {
	root := "/sys/class/kfd/kfd"
	if len(topoRootParam) == 1 {
		root = topoRootParam[0]
	}

	ids := make(map[int]string)
	propFiles, err := fs.Glob(root + "/topology/nodes/*/properties")
	if err != nil {
		return ids
	}

	for _, propFile := range propFiles {
		minor, err := ParseTopologyProperties(fs, propFile, renderMinorRe)
		if err != nil || minor <= 0 || minor > math.MaxInt32 {
			continue
		}
		uid, err := ParseTopologyPropertiesString(fs, propFile, topoUniqueIdRe)
		if err != nil {
			continue
		}
		ids[int(minor)] = uid
	}

	return ids
}

// ParseTopologyProperties reads the KFD topology properties file at path via fs
// and returns the first capture group of the first line matched by re, parsed
// as a base-10 int64. The format is usually one "<name> <value>" entry per line.
//
// re must contain at least one capture group. An error is returned if:
//   - the file cannot be read,
//   - scanning the content fails,
//   - no line matches, or
//   - the captured text is not a valid decimal int64.
func ParseTopologyProperties(fs FileSystem, path string, re *regexp.Regexp) (int64, error) {
	content, err := fs.ReadFile(path)
	if err != nil {
		return 0, err
	}

	scanner := bufio.NewScanner(strings.NewReader(string(content)))
	for scanner.Scan() {
		if matches := re.FindStringSubmatch(scanner.Text()); matches != nil {
			// Base 10, not 0: the property patterns capture decimal digits, and
			// base 0 would read a leading-zero value such as "010" as octal.
			return strconv.ParseInt(matches[1], 10, 64)
		}
	}
	// Report a scan failure (e.g. an over-long line) rather than "not found".
	if err := scanner.Err(); err != nil {
		return 0, fmt.Errorf("scanning %s: %w", path, err)
	}

	return 0, fmt.Errorf("property not found in %s", path)
}

// ParseTopologyPropertiesString reads the KFD topology properties file at path
// via fs and returns, as a string, the first capture group of the first line
// matched by re. The format is usually one "<name> <value>" entry per line.
//
// re must contain at least one capture group. An error is returned if the file
// cannot be read, if scanning the content fails, or if no line matches.
func ParseTopologyPropertiesString(fs FileSystem, path string, re *regexp.Regexp) (string, error) {
	content, err := fs.ReadFile(path)
	if err != nil {
		return "", err
	}

	scanner := bufio.NewScanner(strings.NewReader(string(content)))
	for scanner.Scan() {
		if matches := re.FindStringSubmatch(scanner.Text()); matches != nil {
			return matches[1], nil
		}
	}
	// Report a scan failure (e.g. an over-long line) rather than "not found".
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("scanning %s: %w", path, err)
	}

	return "", fmt.Errorf("property not found in %s", path)
}

// GetUniqueIdToDeviceIndexMap maps hexadecimal unique_id strings to device
// indices, using the host filesystem (defaultFS) for discovery. See
// GetUniqueIdToDeviceIndexMapWithFS for the key format.
func GetUniqueIdToDeviceIndexMap() (map[string][]int, error) {
	indexByID, err := GetUniqueIdToDeviceIndexMapWithFS(defaultFS)
	return indexByID, err
}

// GetUniqueIdToDeviceIndexMapWithFS maps each device's topology unique_id,
// written in lower-case hexadecimal, to the indices of the matching entries in
// the slice returned by GetHyDCUsWithFS(fs).
//
// Each unique_id is stored under two keys, with and without a "0x" prefix
// (for example "0x1f" and "1f"), so callers may look up either spelling.
// An entry is skipped when:
//   - it has no renderD node,
//   - its first render node's minor is not in the KFD topology, or
//   - the unique_id is not a decimal uint64.
//
// Any error from GetHyDCUsWithFS is returned unchanged.
func GetUniqueIdToDeviceIndexMapWithFS(fs FileSystem) (map[string][]int, error) {
	devs, err := GetHyDCUsWithFS(fs)
	if err != nil {
		return nil, err
	}

	// GetHyDCUsWithFS consults the same topology but does not return its
	// render-minor map, so the topology is read again here.
	renderDevIds := GetDevIdsFromTopology(fs)
	uniqueIdToIndex := make(map[string][]int)

	for deviceIndex, deviceGroup := range devs {
		for _, device := range deviceGroup.DrmDevices {
			// Device paths look like /dev/dri/renderD128.
			name := filepath.Base(device)
			if !strings.HasPrefix(name, "renderD") {
				continue
			}
			// Only the first render node of an entry is consulted, whether or
			// not it resolves to a unique_id.
			if renderMinor, err := strconv.Atoi(strings.TrimPrefix(name, "renderD")); err == nil {
				if uniqueId, ok := renderDevIds[renderMinor]; ok {
					// Topology unique_ids are decimal; re-encode as hex, both with
					// and without the "0x" prefix.
					if uniqueIdInt, err := strconv.ParseUint(uniqueId, 10, 64); err == nil {
						hexNoPrefix := strconv.FormatUint(uniqueIdInt, 16)
						for _, key := range [...]string{"0x" + hexNoPrefix, hexNoPrefix} {
							uniqueIdToIndex[key] = append(uniqueIdToIndex[key], deviceIndex)
						}
					}
				}
			}
			break
		}
	}

	return uniqueIdToIndex, nil
}