graphics.go 3.34 KB
Newer Older
songlinfeng's avatar
songlinfeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/**
# Copyright (c) 2024, HCUOpt CORPORATION.  All rights reserved.
**/

package discover

import (
	"dtk-container-toolkit/internal/info/drm"
	"dtk-container-toolkit/internal/logger"
	"dtk-container-toolkit/internal/lookup"
	"dtk-container-toolkit/internal/lookup/root"
	"fmt"
	"os"
)

// NewDRMNodesDiscoverer returns a discoverer for the DRM device nodes associated with the specified visible devices.
//
// Node selection is driven by requestBusIds; busIds is forwarded to
// newDRMDeviceDiscoverer, which does not currently use it. Errors from
// building the discoverer are wrapped with %w so callers can still inspect the
// underlying cause with errors.Is / errors.As.
//
// TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, busIds []string, requestBusIds []string, devRoot string) (Discover, error) {
	drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, busIds, requestBusIds, devRoot)
	if err != nil {
		return nil, fmt.Errorf("failed to create DRM device discoverer: %w", err)
	}
	return drmDeviceNodes, nil
}

// newDRMDeviceDiscoverer creates a discoverer for the DRM devices associated with the requested devices.
//
// Candidate nodes are the /dev/dri/card* and /dev/dri/renderD* character
// devices found by NewCharDeviceDiscoverer with devRoot as its root. They are
// narrowed to the node paths that drm.GetDeviceNodesByBusID returns for each
// bus ID in requestBusIds. If requestBusIds is empty the path filter is empty,
// so presumably no DRM nodes are selected (depends on newFilteredDiscoverer).
//
// The first lookup failure aborts construction; the error is wrapped with %w.
//
// NOTE(review): busIds is accepted but never read here — confirm whether it
// is still needed or can be dropped from the callers.
func newDRMDeviceDiscoverer(logger logger.Interface, busIds []string, requestBusIds []string, devRoot string) (Discover, error) {
	allDevices := NewCharDeviceDiscoverer(
		logger,
		devRoot,
		[]string{
			"/dev/dri/card*",
			"/dev/dri/renderD*",
		},
	)

	// Build the set of DRM node paths that belong to the requested bus IDs.
	filter := make(selectDeviceByPath)
	for _, busID := range requestBusIds {
		drmDeviceNodes, err := drm.GetDeviceNodesByBusID(busID)
		if err != nil {
			return nil, fmt.Errorf("failed to determine DRM devices for %v: %w", busID, err)
		}
		logger.Infof("selected drm nodes from bus %s: %v", busID, drmDeviceNodes)
		for _, drmDeviceNode := range drmDeviceNodes {
			filter[drmDeviceNode] = true
		}
	}

	// Apply the path filter to every discovered DRM node so only the nodes of
	// the requested devices remain.
	return newFilteredDiscoverer(logger, allDevices, filter), nil
}

// selectDeviceByPath is a Filter backed by a set of device paths: a device is
// kept only when its Path maps to true. Mounts and hooks always pass.
type selectDeviceByPath map[string]bool

// Compile-time assertion that selectDeviceByPath satisfies Filter.
var _ Filter = (*selectDeviceByPath)(nil)

// DeviceIsSelected reports whether device.Path is present in the set and
// marked true; unknown paths are rejected.
func (s selectDeviceByPath) DeviceIsSelected(device Device) bool {
	selected, found := s[device.Path]
	return found && selected
}

// MountIsSelected accepts every mount; this filter only narrows devices.
func (s selectDeviceByPath) MountIsSelected(_ Mount) bool {
	return true
}

// HookIsSelected accepts every hook; this filter only narrows devices.
func (s selectDeviceByPath) HookIsSelected(_ Hook) bool {
	return true
}

// NewCommonHCUDiscoverer creates a discoverer for the mounts required by HCU.
//
// The returned discoverer merges:
//   - the /dev/kfd, /dev/mkfd and /dev/mem character devices, discovered with
//     driver.Root passed as the device root;
//   - when isMount is true, the "hyhal" directory located by a directory
//     locator searching /usr/local and /opt (WithCount(1));
//   - the discoverer returned by NewUserGroupDiscover;
//   - only when /usr/local/hyhal exists and is a directory, a symlink hook
//     created via dtkCDIHookPath for "/usr/local/hyhal::/opt/hyhal".
//
// The error result is always nil today; it is kept for API compatibility.
func NewCommonHCUDiscoverer(logger logger.Interface, dtkCDIHookPath string, driver *root.Driver, isMount bool) (Discover, error) {
	metaDevices := NewCharDeviceDiscoverer(
		logger,
		driver.Root,
		[]string{
			"/dev/kfd",
			"/dev/mkfd",
			"/dev/mem",
		},
	)

	var directory []string
	if isMount {
		directory = append(directory, "hyhal")
	}
	libraries := NewMounts(
		logger,
		lookup.NewDirectoryLocator(
			lookup.WithLogger(logger),
			lookup.WithCount(1),
			lookup.WithSearchPaths("/usr/local", "/opt"),
		),
		driver.Root,
		directory,
	)

	discoverers := []Discover{
		metaDevices,
		libraries,
		NewUserGroupDiscover(logger),
	}

	// Append the symlink hook only when its target exists, so that no
	// zero-value Hook (empty path/args) is merged when /usr/local/hyhal is
	// absent. Stat failures are treated as "absent" (best effort).
	// NOTE(review): this checks the host path rather than driver.Root and does
	// not depend on isMount — confirm both are intended.
	if info, err := os.Stat("/usr/local/hyhal"); err == nil && info.IsDir() {
		discoverers = append(discoverers, CreateSymlinkHook(dtkCDIHookPath, []string{"/usr/local/hyhal::/opt/hyhal"}))
	}

	return Merge(discoverers...), nil
}