Commit d7e13eb9 authored by songlinfeng's avatar songlinfeng
Browse files

add dtk-container-toolkit

parent fcdba4f3
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package lookup
import (
"dtk-container-toolkit/internal/ldcache"
"dtk-container-toolkit/internal/logger"
"fmt"
)
type ldcacheLocator struct {
logger logger.Interface
cache ldcache.LDCache
}
var _ Locator = (*ldcacheLocator)(nil)
// NewLibraryLocator creates a library locator using the specified options.
func NewLibraryLocator(opts ...Option) Locator {
b := newBuilder(opts...)
// If search paths are already specified, we return a locator for the specified search paths.
if len(b.searchPaths) > 0 {
return NewSymlinkLocator(
WithLogger(b.logger),
WithSearchPaths(b.searchPaths...),
WithRoot("/"),
)
}
opts = append(opts,
WithSearchPaths([]string{
"/",
"/opt/hyhal/lib",
"/usr/lib64",
"/usr/lib/x86_64-linux-gnu",
"/usr/lib/aarch64-linux-gnu",
"/lib64",
"/lib/x86_64-linux-gnu",
"/lib/aarch64-linux-gnu",
}...),
)
// We construct a symlink locator for expected library locations.
symlinkLocator := NewSymlinkLocator(opts...)
l := First(
symlinkLocator,
newLdcacheLocator(opts...),
)
return l
}
func newLdcacheLocator(opts ...Option) Locator {
b := newBuilder(opts...)
cache, err := ldcache.New(b.logger, b.root)
if err != nil {
// If we failed to open the LDCache, we default to a symlink locator.
b.logger.Warningf("Failed to load ldcache: %v", err)
return nil
}
return &ldcacheLocator{
logger: b.logger,
cache: cache,
}
}
// Locate finds the specified libraryname.
// If the input is a library name, the ldcache is searched otherwise the
// provided path is resolved as a symlink.
func (l ldcacheLocator) Locate(libname string) ([]string, error) {
paths32, paths64 := l.cache.Lookup(libname)
if len(paths32) > 0 {
l.logger.Warningf("Ignoring 32-bit libraries for %v: %v", libname, paths32)
}
if len(paths64) == 0 {
return nil, fmt.Errorf("64-bit library %v: %w", libname, ErrNotFound)
}
return paths64, nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package lookup
import "errors"
//go:generate moq -stub -out locator_mock.go . Locator
// Locator defines the interface for locating files on a system.
type Locator interface {
Locate(string) ([]string, error)
}
// ErrNotFound indicates that a specified pattern or file could not be found.
var ErrNotFound = errors.New("not found")
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package lookup
import (
"sync"
)
// Ensure, that LocatorMock does implement Locator.
// If this is not the case, regenerate this file with moq.
var _ Locator = &LocatorMock{}
// LocatorMock is a mock implementation of Locator.
//
// func TestSomethingThatUsesLocator(t *testing.T) {
//
// // make and configure a mocked Locator
// mockedLocator := &LocatorMock{
// LocateFunc: func(s string) ([]string, error) {
// panic("mock out the Locate method")
// },
// }
//
// // use mockedLocator in code that requires Locator
// // and then make assertions.
//
// }
type LocatorMock struct {
// LocateFunc mocks the Locate method.
LocateFunc func(s string) ([]string, error)
// calls tracks calls to the methods.
calls struct {
// Locate holds details about calls to the Locate method.
Locate []struct {
// S is the s argument value.
S string
}
}
lockLocate sync.RWMutex
}
// Locate calls LocateFunc.
func (mock *LocatorMock) Locate(s string) ([]string, error) {
callInfo := struct {
S string
}{
S: s,
}
mock.lockLocate.Lock()
mock.calls.Locate = append(mock.calls.Locate, callInfo)
mock.lockLocate.Unlock()
if mock.LocateFunc == nil {
var (
stringsOut []string
errOut error
)
return stringsOut, errOut
}
return mock.LocateFunc(s)
}
// LocateCalls gets all the calls that were made to Locate.
// Check the length with:
//
// len(mockedLocator.LocateCalls())
func (mock *LocatorMock) LocateCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLocate.RLock()
calls = mock.calls.Locate
mock.lockLocate.RUnlock()
return calls
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package lookup
import (
"errors"
)
type first []Locator
// First returns a locator that returns the first non-empty match
func First(locators ...Locator) Locator {
var f first
for _, l := range locators {
if l == nil {
continue
}
f = append(f, l)
}
return f
}
// Locate returns the results for the first locator that returns a non-empty non-error result.
func (f first) Locate(pattern string) ([]string, error) {
var allErrors []error
for _, l := range f {
if l == nil {
continue
}
candidates, err := l.Locate(pattern)
if err != nil {
allErrors = append(allErrors, err)
continue
}
if len(candidates) > 0 {
return candidates, nil
}
}
return nil, errors.Join(allErrors...)
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package lookup
import (
"os"
"path"
"path/filepath"
"strings"
)
const (
envPath = "PATH"
)
var (
defaultPath = []string{"/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"}
)
// GetPaths returns a list of paths for a specified root. These are constructed from the
// PATH environment variable, a default path list, and the supplied root.
func GetPaths(root string) []string {
dirs := filepath.SplitList(os.Getenv(envPath))
inDirs := make(map[string]bool)
for _, d := range dirs {
inDirs[d] = true
}
// directories from the environment have higher precedence
for _, d := range defaultPath {
if inDirs[d] {
// We don't add paths that are already included
continue
}
dirs = append(dirs, d)
}
if root != "" && root != "/" {
rootDirs := []string{}
for _, dir := range dirs {
rootDirs = append(rootDirs, path.Join(root, dir))
}
// directories with the root prefix have higher precedence
dirs = append(rootDirs, dirs...)
}
return dirs
}
// GetPath returns a colon-separated path value that can be used to set the PATH
// environment variable
func GetPath(root string) string {
return strings.Join(GetPaths(root), ":")
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package root
import "dtk-container-toolkit/internal/logger"
type Option func(*Driver)
func WithLogger(logger logger.Interface) Option {
return func(d *Driver) {
d.logger = logger
}
}
func WithDriverRoot(root string) Option {
return func(d *Driver) {
d.Root = root
}
}
func WithLibrarySearchPaths(paths ...string) Option {
return func(d *Driver) {
d.librarySearchPaths = paths
}
}
func WithConfigSearchPaths(paths ...string) Option {
return func(d *Driver) {
d.configSearchPaths = paths
}
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package root
import (
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/lookup"
"os"
"path/filepath"
)
// Driver represents a filesystem in which a set of drivers or devices is defined.
type Driver struct {
logger logger.Interface
// Root represents the root from the perspective of the driver libraries and binaries.
Root string
// librarySearchPaths specifies explicit search paths for discovering libraries.
librarySearchPaths []string
// configSearchPaths specified explicit search paths for discovering driver config files.
configSearchPaths []string
}
// New creates a new Driver root using the specified options.
func New(opts ...Option) *Driver {
d := &Driver{}
for _, opt := range opts {
opt(d)
}
if d.logger == nil {
d.logger = logger.New()
}
return d
}
// Files returns a Locator for arbitrary driver files.
func (r *Driver) Files(opts ...lookup.Option) lookup.Locator {
return lookup.NewFileLocator(
append(
opts,
lookup.WithLogger(r.logger),
lookup.WithRoot(r.Root),
)...,
)
}
// Libraries returns a Locator for driver libraries.
func (r *Driver) Libraries() lookup.Locator {
return lookup.NewLibraryLocator(
lookup.WithLogger(r.logger),
lookup.WithRoot(r.Root),
lookup.WithSearchPaths(normalizeSearchPaths(r.librarySearchPaths...)...),
)
}
// Configs returns a locator for driver configs.
// If configSearchPaths is specified, these paths are used as absolute paths,
// otherwise, /etc and /usr/share are searched.
func (r *Driver) Configs() lookup.Locator {
return lookup.NewFileLocator(r.configSearchOptions()...)
}
func (r *Driver) configSearchOptions() []lookup.Option {
if len(r.configSearchPaths) > 0 {
return []lookup.Option{
lookup.WithLogger(r.logger),
lookup.WithRoot("/"),
lookup.WithSearchPaths(normalizeSearchPaths(r.configSearchPaths...)...),
}
}
searchPaths := []string{"/etc"}
searchPaths = append(searchPaths, xdgDataDirs()...)
return []lookup.Option{
lookup.WithLogger(r.logger),
lookup.WithRoot(r.Root),
lookup.WithSearchPaths(searchPaths...),
}
}
// normalizeSearchPaths takes a list of paths and normalized these.
// Each of the elements in the list is expanded if it is a path list and the
// resultant list is returned.
// This allows, for example, for the contents of `PATH` or `LD_LIBRARY_PATH` to
// be passed as a search path directly.
func normalizeSearchPaths(paths ...string) []string {
var normalized []string
for _, path := range paths {
normalized = append(normalized, filepath.SplitList(path)...)
}
return normalized
}
// xdgDataDirs finds the paths as specified in the environment variable XDG_DATA_DIRS.
// See https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html.
func xdgDataDirs() []string {
if dirs, exists := os.LookupEnv("XDG_DATA_DIRS"); exists && dirs != "" {
return normalizeSearchPaths(dirs)
}
return []string{"/usr/local/share", "/usr/share"}
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package lookup
import (
"dtk-container-toolkit/internal/lookup/symlinks"
"fmt"
"path/filepath"
)
type symlinkChain struct {
file
}
type symlink struct {
file
}
// NewSymlinkChainLocator creates a locator that can be used for locating files through symlinks.
func NewSymlinkChainLocator(opts ...Option) Locator {
f := newFileLocator(opts...)
l := symlinkChain{
file: *f,
}
return &l
}
// NewSymlinkLocator creats a locator that can be used for locating files through symlinks.
func NewSymlinkLocator(opts ...Option) Locator {
f := newFileLocator(opts...)
l := symlink{
file: *f,
}
return &l
}
// Locate finds the specified pattern at the specified root.
// If the file is a symlink, the link is followed and all candidates to the final target are returned.
func (p symlinkChain) Locate(pattern string) ([]string, error) {
candidates, err := p.file.Locate(pattern)
if err != nil {
return nil, err
}
if len(candidates) == 0 {
return candidates, nil
}
found := make(map[string]bool)
for len(candidates) > 0 {
candidate := candidates[0]
candidates = candidates[:len(candidates)-1]
if found[candidate] {
continue
}
found[candidate] = true
target, err := symlinks.Resolve(candidate)
if err != nil {
return nil, fmt.Errorf("error resolving symlink: %v", err)
}
if !filepath.IsAbs(target) {
target, err = filepath.Abs(filepath.Join(filepath.Dir(candidate), target))
if err != nil {
return nil, fmt.Errorf("failed to construct absolute path: %v", err)
}
}
p.logger.Debugf("Resolved link: '%v' => '%v'", candidate, target)
if !found[target] {
candidates = append(candidates, target)
}
}
var filenames []string
for f := range found {
filenames = append(filenames, f)
}
return filenames, nil
}
// Locate finds the specified pattern at the specified root.
// If the file is a symlink, the link is resolved and the target returned.
func (p symlink) Locate(pattern string) ([]string, error) {
candidates, err := p.file.Locate(pattern)
if err != nil {
return nil, err
}
var targets []string
seen := make(map[string]bool)
for _, candidate := range candidates {
target, err := filepath.EvalSymlinks(candidate)
if err != nil {
return nil, fmt.Errorf("failed to resolve link: %w", err)
}
if seen[target] {
continue
}
seen[target] = true
targets = append(targets, target)
}
return targets, err
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package symlinks
import (
"fmt"
"os"
)
// Resolve returns the link target of the specified filename or the filename if it is not a link.
func Resolve(filename string) (string, error) {
info, err := os.Lstat(filename)
if err != nil {
return filename, fmt.Errorf("failed to get file info: %v", info)
}
if info.Mode()&os.ModeSymlink == 0 {
return filename, nil
}
return os.Readlink(filename)
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
"github.com/opencontainers/runtime-spec/specs-go"
)
func NewCapModifier(logger logger.Interface, toAddCaps []string, toRemoveCaps []string) oci.SpecModifier {
return &capModifier{
logger: logger,
toAddCaps: toAddCaps,
toRemoveCaps: toRemoveCaps,
}
}
type capModifier struct {
logger logger.Interface
toAddCaps []string
toRemoveCaps []string
}
var _ oci.SpecModifier = (*capModifier)(nil)
func (m *capModifier) Modify(spec *specs.Spec) error {
if spec == nil || spec.Process == nil || spec.Process.Capabilities == nil {
return nil
}
bounding := spec.Process.Capabilities.Bounding
effective := spec.Process.Capabilities.Effective
permitted := spec.Process.Capabilities.Permitted
if len(m.toAddCaps) != 0 {
for _, cap := range m.toAddCaps {
bounding = append(bounding, cap)
effective = append(effective, cap)
permitted = append(permitted, cap)
}
bounding = unqiueCaps(bounding)
effective = unqiueCaps(effective)
permitted = unqiueCaps(permitted)
}
if len(m.toRemoveCaps) != 0 {
bounding = removeCaps(bounding, m.toRemoveCaps)
effective = removeCaps(effective, m.toRemoveCaps)
permitted = removeCaps(permitted, m.toRemoveCaps)
}
spec.Process.Capabilities.Bounding = bounding
spec.Process.Capabilities.Effective = effective
spec.Process.Capabilities.Permitted = permitted
return nil
}
func unqiueCaps(caps []string) []string {
var uCaps []string
mCaps := make(map[string]bool)
for _, v := range caps {
if _, ok := mCaps[v]; !ok {
mCaps[v] = true
uCaps = append(uCaps, v)
}
}
return uCaps
}
func removeCaps(orgCaps []string, reCaps []string) []string {
toRemove := make(map[string]bool)
for _, cap := range reCaps {
toRemove[cap] = true
}
var caps []string
for _, v := range orgCaps {
if _, ok := toRemove[v]; ok {
continue
}
caps = append(caps, v)
}
return caps
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/config"
"dtk-container-toolkit/internal/config/image"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/modifier/cdi"
"dtk-container-toolkit/internal/oci"
"dtk-container-toolkit/pkg/c3000cdi"
"dtk-container-toolkit/pkg/c3000cdi/spec"
"fmt"
"strings"
"tags.cncf.io/container-device-interface/pkg/parser"
)
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
// CDI specifications available on the system. The DTK_VISIBLE_DEVICES environment variable is
// used to select the devices to include.
func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
devices, err := getDevicesFromSpec(logger, ociSpec, cfg)
if err != nil {
return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
}
if len(devices) == 0 {
logger.Debugf("No devices requested; no modification required.")
return nil, nil
}
logger.Debugf("Creating CDI modifier for devices: %v", devices)
automaticDevices := filterAutomaticDevices(devices)
if len(automaticDevices) != len(devices) && len(automaticDevices) > 0 {
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.c-3000.com' is not supported when requesting other CDI devices")
}
if len(automaticDevices) > 0 {
automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices)
if err == nil {
return automaticModifier, nil
}
logger.Warningf("Failed to create the automatic CDI modifier: %w", err)
logger.Debugf("Falling back to the standard CDI modifier")
}
return cdi.New(
cdi.WithLogger(logger),
cdi.WithDevices(devices...),
cdi.WithSpecDirs(cfg.DTKContainerRuntimeConfig.Modes.CDI.SpecDirs...),
)
}
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
annotationDevices, err := getAnnotationDevices(cfg.DTKContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations)
if err != nil {
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
}
if len(annotationDevices) > 0 {
return annotationDevices, nil
}
container, err := image.NewDTKImageFromSpec(rawSpec)
if err != nil {
return nil, err
}
if cfg.AcceptDeviceListAsVolumeMounts {
mountDevices := container.CDIDevicesFromMounts()
if len(mountDevices) > 0 {
return mountDevices, nil
}
}
var devices []string
seen := make(map[string]bool)
for _, name := range container.VisibleDevicesFromEnvVar() {
if !parser.IsQualifiedName(name) {
name = fmt.Sprintf("%s=%s", cfg.DTKContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
}
if seen[name] {
logger.Debugf("Ignoring duplicate device %q", name)
continue
}
devices = append(devices, name)
}
if len(devices) == 0 {
return nil, nil
}
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged(rawSpec) {
return devices, nil
}
logger.Warningf("Ignoring devices specified in DTK_VISIBLE_DEVICES: %v", devices)
return nil, nil
}
// getAnnotationDevices returns a list of devices specified in the annotations.
// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of
// fully-qualified CDI devices names. If any device name is not fully-quality an error is returned.
// The list of returned devices is deduplicated.
func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]string, error) {
devicesByKey := make(map[string][]string)
for key, value := range annotations {
for _, prefix := range prefixes {
if strings.HasPrefix(key, prefix) {
devicesByKey[key] = strings.Split(value, ",")
}
}
}
seen := make(map[string]bool)
var annotationDevices []string
for key, devices := range devicesByKey {
for _, device := range devices {
if !parser.IsQualifiedName(device) {
return nil, fmt.Errorf("invalid device name %q in annotation %q", device, key)
}
if seen[device] {
continue
}
annotationDevices = append(annotationDevices, device)
seen[device] = true
}
}
return annotationDevices, nil
}
// filterAutomaticDevices searches for "automatic" device names in the input slice.
// "Automatic" devices are a well-defined list of CDI device names which, when requested,
// trigger the generation of a CDI spec at runtime. This removes the need to generate a
// CDI spec on the system a-priori as well as keep it up-to-date.
func filterAutomaticDevices(devices []string) []string {
var automatic []string
for _, device := range devices {
vendor, class, _ := parser.ParseDevice(device)
if vendor == "runtime.c-3000.com" && class == "hcu" {
automatic = append(automatic, device)
}
}
return automatic
}
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
spec, err := generateAutomaticCDISpec(logger, cfg, devices)
if err != nil {
return nil, fmt.Errorf("failed to generate CDI spec: %w", err)
}
cdiModifier, err := cdi.New(
cdi.WithLogger(logger),
cdi.WithSpec(spec.Raw()),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI modifier: %w", err)
}
return cdiModifier, nil
}
func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) {
cdilib, err := c3000cdi.New(
c3000cdi.WithLogger(logger),
c3000cdi.WithDTKCDIHookPath(cfg.DTKCTKConfig.Path),
c3000cdi.WithVendor("runtime.c-3000.com"),
c3000cdi.WithClass("hcu"),
)
if err != nil {
return nil, fmt.Errorf("failed to construct CDI library: %w", err)
}
identifiers := []string{}
for _, device := range devices {
_, _, id := parser.ParseDevice(device)
identifiers = append(identifiers, id)
}
deviceSpecs, err := cdilib.GetDeviceSpecsByID(identifiers...)
if err != nil {
return nil, fmt.Errorf("failed to get CDI device specs: %w", err)
}
commonEdits, err := cdilib.GetCommonEdits()
if err != nil {
return nil, fmt.Errorf("failed to get common CDI spec edits: %w", err)
}
return spec.New(
spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*commonEdits.ContainerEdits),
spec.WithVendor("runtime.c-3000.com"),
spec.WithClass("hcu"),
)
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package cdi
import (
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
"fmt"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"
)
type builder struct {
logger logger.Interface
specDirs []string
devices []string
cdiSpec *specs.Spec
}
// Option represents a functional option for creating a CDI mofifier.
type Option func(*builder)
// New creates a new CDI modifier.
func New(opts ...Option) (oci.SpecModifier, error) {
b := &builder{}
for _, opt := range opts {
opt(b)
}
if b.logger == nil {
b.logger = logger.New()
}
return b.build()
}
// build uses the applied options and constructs a CDI modifier using the builder.
func (m builder) build() (oci.SpecModifier, error) {
if len(m.devices) == 0 && m.cdiSpec == nil {
return nil, nil
}
if m.cdiSpec != nil {
modifier := fromCDISpec{
cdiSpec: &cdi.Spec{Spec: m.cdiSpec},
}
return modifier, nil
}
registry, err := cdi.NewCache(
cdi.WithAutoRefresh(false),
cdi.WithSpecDirs(m.specDirs...),
)
if err != nil {
return nil, fmt.Errorf("failed to create CDI registry: %v", err)
}
modifier := fromRegistry{
logger: m.logger,
registry: registry,
devices: m.devices,
}
return modifier, nil
}
// WithLogger sets the logger for the CDI modifier builder.
func WithLogger(logger logger.Interface) Option {
return func(b *builder) {
b.logger = logger
}
}
// WithSpecDirs sets the spec directories for the CDI modifier builder.
func WithSpecDirs(specDirs ...string) Option {
return func(b *builder) {
b.specDirs = specDirs
}
}
// WithDevices sets the devices for the CDI modifier builder.
func WithDevices(devices ...string) Option {
return func(b *builder) {
b.devices = devices
}
}
// WithSpec sets the spec for the CDI modifier builder.
func WithSpec(spec *specs.Spec) Option {
return func(b *builder) {
b.cdiSpec = spec
}
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package cdi
import (
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
"errors"
"fmt"
"github.com/opencontainers/runtime-spec/specs-go"
"tags.cncf.io/container-device-interface/pkg/cdi"
)
// fromRegistry represents the modifications performed using a CDI registry.
type fromRegistry struct {
logger logger.Interface
registry *cdi.Cache
devices []string
}
var _ oci.SpecModifier = (*fromRegistry)(nil)
// Modify applies the modifications defined by the CDI registry to the incoming OCI spec.
func (m fromRegistry) Modify(spec *specs.Spec) error {
if err := m.registry.Refresh(); err != nil {
m.logger.Debugf("The following error was triggered when refreshing the CDI registry: %v", err)
}
m.logger.Debugf("Injecting devices using CDI: %v", m.devices)
unresolvedDevices, err := m.registry.InjectDevices(spec, m.devices...)
if unresolvedDevices != nil {
m.logger.Warningf("could not resolve CDI devices: %v", unresolvedDevices)
}
if err != nil {
var refreshErrors []error
for _, rerrs := range m.registry.GetErrors() {
refreshErrors = append(refreshErrors, rerrs...)
}
if rerr := errors.Join(refreshErrors...); rerr != nil {
// We log the errors that may have been generated while refreshing the CDI registry.
// These may be due to malformed specifications or device name conflicts that could be
// the cause of an injection failure.
m.logger.Warningf("Refreshing the CDI registry generated errors: %v", rerr)
}
return fmt.Errorf("failed to inject CDI devices: %v", err)
}
return nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package cdi
import (
"dtk-container-toolkit/internal/oci"
"fmt"
"github.com/opencontainers/runtime-spec/specs-go"
"tags.cncf.io/container-device-interface/pkg/cdi"
)
// fromCDISpec represents the modifications performed from a raw CDI spec.
type fromCDISpec struct {
cdiSpec *cdi.Spec
}
var _ oci.SpecModifier = (*fromCDISpec)(nil)
// Modify applies the mofiications defined by the raw CDI spec to the incomming OCI spec.
func (m fromCDISpec) Modify(spec *specs.Spec) error {
for _, device := range m.cdiSpec.Devices {
device := device
cdiDevice := cdi.Device{
Device: &device,
}
if err := cdiDevice.ApplyEdits(spec); err != nil {
return fmt.Errorf("failed to apply edits for device %q: %v", cdiDevice.GetQualifiedName(), err)
}
}
return m.cdiSpec.ApplyEdits(spec)
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"bytes"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/lookup"
"dtk-container-toolkit/internal/oci"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"syscall"
"github.com/opencontainers/runtime-spec/specs-go"
)
type copyModifier struct {
logger logger.Interface
}
// Gives a number indicating the device driver to be used to access the passed device
func major(device uint64) uint64 {
return (device >> 8) & 0xfff
}
// Gives a number that serves as a flag to the device driver for the passed device
func minor(device uint64) uint64 {
return (device & 0xff) | ((device >> 12) & 0xfff00)
}
func mkdev(major int64, minor int64) uint32 {
return uint32(((minor & 0xfff00) << 12) | ((major & 0xfff) << 8) | (minor & 0xff))
}
func CopyFile(source string, dest string) error {
si, err := os.Lstat(source)
if err != nil {
return err
}
st, ok := si.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("could not convert to syscall.Stat_t")
}
uid := int(st.Uid)
gid := int(st.Gid)
modeType := si.Mode() & os.ModeType
// Handle symlinks
if modeType == os.ModeSymlink {
target, err := os.Readlink(source)
if err != nil {
return err
}
if _, err := os.Lstat(dest); err == nil {
if err := os.Remove(dest); err != nil {
return fmt.Errorf("failed to remove existing file: %w", err)
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to check if destination exists: %w", err)
}
if err := os.Symlink(target, dest); err != nil {
return err
}
}
// Handle device files
if modeType == os.ModeDevice {
devMajor := int64(major(uint64(st.Rdev)))
devMinor := int64(minor(uint64(st.Rdev)))
mode := uint32(si.Mode() & os.ModePerm)
if si.Mode()&os.ModeCharDevice != 0 {
mode |= syscall.S_IFCHR
} else {
mode |= syscall.S_IFBLK
}
if err := syscall.Mknod(dest, mode, int(mkdev(devMajor, devMinor))); err != nil {
return err
}
}
// Handle regular files
if si.Mode().IsRegular() {
err = copyInternal(source, dest)
if err != nil {
return err
}
}
// Chown the file
if err := os.Lchown(dest, uid, gid); err != nil {
return err
}
// Chmod the file
if !(modeType == os.ModeSymlink) {
if err := os.Chmod(dest, si.Mode()); err != nil {
return err
}
}
return nil
}
func copyInternal(source, dest string) (retErr error) {
sf, err := os.Open(source)
if err != nil {
return err
}
defer sf.Close()
df, err := os.Create(dest)
if err != nil {
return err
}
defer func() {
err := df.Close()
if retErr == nil {
retErr = err
}
}()
_, err = io.Copy(df, sf)
return err
}
func RemoveSymlinkOrDirectory(source string) error {
info, err := os.Lstat(source)
if err != nil {
return fmt.Errorf("failed to lstat source: %w", err)
}
if info.Mode()&os.ModeSymlink != 0 {
err := os.Remove(source)
if err != nil {
return fmt.Errorf("failed to remove symlink: %w", err)
}
} else if info.IsDir() {
err := os.RemoveAll(source)
if err != nil {
return fmt.Errorf("failed to remove directory: %w", err)
}
} else {
return fmt.Errorf("source is neither a symlink nor a directory: %s", source)
}
return nil
}
func CopyDirectory(srcDir, dstDir string) error {
RemoveSymlinkOrDirectory(dstDir)
fi, err := os.Stat(srcDir)
if err != nil {
return err
}
st, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("could not convert to syscall.Stat_t")
}
if err := os.MkdirAll(dstDir, fi.Mode()); err != nil {
return err
}
if err := os.Lchown(dstDir, int(st.Uid), int(st.Gid)); err != nil {
return err
}
if err := os.Chmod(dstDir, fi.Mode()); err != nil {
return err
}
return filepath.Walk(srcDir, func(srcPath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, err := filepath.Rel(srcDir, srcPath)
if err != nil {
return err
}
dstPath := filepath.Join(dstDir, relPath)
if info.IsDir() {
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("could not convert to syscall.Stat_t")
}
uid := int(st.Uid)
gid := int(st.Gid)
if err := os.MkdirAll(dstPath, info.Mode()); err != nil {
return err
}
if err := os.Lchown(dstPath, uid, gid); err != nil {
return err
}
if err := os.Chmod(dstPath, info.Mode()); err != nil {
return err
}
return nil
}
return CopyFile(srcPath, dstPath)
})
}
func NewCopyModifier(logger logger.Interface) (oci.SpecModifier, error) {
m := copyModifier{
logger: logger,
}
return &m, nil
}
type VFS interface {
Lstat(name string) (os.FileInfo, error)
Readlink(name string) (string, error)
}
type osVFS struct{}
func (o osVFS) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) }
func (o osVFS) Readlink(name string) (string, error) { return os.Readlink(name) }
// IsNotExist tells you if err is an error that implies that either the path
// accessed does not exist (or path components don't exist). This is
// effectively a more broad version of os.IsNotExist.
func IsNotExist(err error) bool {
// Check that it's not actually an ENOTDIR, which in some cases is a more
// convoluted case of ENOENT (usually involving weird paths).
return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT)
}
func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) {
// Use the os.* VFS implementation if none was specified.
if vfs == nil {
vfs = osVFS{}
}
unsafePath = filepath.FromSlash(unsafePath)
var path bytes.Buffer
n := 0
for unsafePath != "" {
if n > 255 {
return "", &os.PathError{Op: "SecureJoin", Path: root + string(filepath.Separator) + unsafePath, Err: syscall.ELOOP}
}
if v := filepath.VolumeName(unsafePath); v != "" {
unsafePath = unsafePath[len(v):]
}
// Next path component, p.
i := strings.IndexRune(unsafePath, filepath.Separator)
var p string
if i == -1 {
p, unsafePath = unsafePath, ""
} else {
p, unsafePath = unsafePath[:i], unsafePath[i+1:]
}
// Create a cleaned path, using the lexical semantics of /../a, to
// create a "scoped" path component which can safely be joined to fullP
// for evaluation. At this point, path.String() doesn't contain any
// symlink components.
cleanP := filepath.Clean(string(filepath.Separator) + path.String() + p)
if cleanP == string(filepath.Separator) {
path.Reset()
continue
}
fullP := filepath.Clean(root + cleanP)
// Figure out whether the path is a symlink.
_, err := vfs.Lstat(fullP)
if err != nil && !IsNotExist(err) {
return "", err
}
// Treat non-existent path components the same as non-symlinks (we
// can't do any better here).
path.WriteString(p)
path.WriteRune(filepath.Separator)
}
// We have to clean path.String() here because it may contain '..'
// components that are entirely lexical, but would be misleading otherwise.
// And finally do a final clean to ensure that root is also lexically
// clean.
fullP := filepath.Clean(string(filepath.Separator) + path.String())
return filepath.Clean(root + fullP), nil
}
func (m copyModifier) Modify(spec *specs.Spec) error {
locator := lookup.NewDirectoryLocator(
lookup.WithLogger(m.logger),
lookup.WithCount(1),
lookup.WithSearchPaths("/usr/local", "/opt"),
)
candidate := "hyhal"
located, err := locator.Locate(candidate)
if err != nil {
m.logger.Warningf("Could not locate %v: %v", candidate, err)
return nil
}
if len(located) == 0 {
m.logger.Warningf("Missing %v", candidate)
return nil
}
m.logger.Debugf("Located %v as %v", candidate, located)
for _, path := range located {
container_path, err := SecureJoinVFS(spec.Root.Path, path, nil)
if err != nil {
return err
}
return CopyDirectory(path, container_path)
}
return nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/discover"
"dtk-container-toolkit/internal/edits"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
"fmt"
"github.com/opencontainers/runtime-spec/specs-go"
)
type discoverModifier struct {
logger logger.Interface
discoverer discover.Discover
}
// NewModifierFromDiscoverer creates a modifier that applies the discovered
// modifications to an OCI spec if required by the runtime wrapper.
func NewModifierFromDiscoverer(logger logger.Interface, d discover.Discover) (oci.SpecModifier, error) {
m := discoverModifier{
logger: logger,
discoverer: d,
}
return &m, nil
}
// Modify applies the modifications required by discoverer to the incomming OCI spec.
// These modifications are applied in-place.
func (m discoverModifier) Modify(spec *specs.Spec) error {
specEdits, err := edits.NewSpecEdits(m.logger, m.discoverer)
if err != nil {
return fmt.Errorf("failed to get required container edits: %v", err)
}
return specEdits.Modify(spec)
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/config"
"dtk-container-toolkit/internal/config/image"
"dtk-container-toolkit/internal/discover"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/lookup/root"
"dtk-container-toolkit/internal/oci"
"fmt"
)
// NewFeatureGatedModifier creates the modifiers for optional features.
// These include:
//
// DTK_MOFED=enabled
//
// If not devices are selected, no changes are made.
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.DTK, driver *root.Driver) (oci.SpecModifier, error) {
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested")
return nil, nil
}
var discoverers []discover.Discover
if image.Getenv("DTK_MOFED") == "enabled" {
d, err := discover.NewMOFEDDiscoverer(logger, driver.Root)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
}
discoverers = append(discoverers, d)
}
return NewModifierFromDiscoverer(logger, discover.Merge(discoverers...))
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/config"
"dtk-container-toolkit/internal/config/image"
"dtk-container-toolkit/internal/discover"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/lookup"
"dtk-container-toolkit/internal/lookup/root"
"dtk-container-toolkit/internal/oci"
"fmt"
"path/filepath"
"sort"
"strconv"
)
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the DTK_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.DTK, driver *root.Driver, isMount bool) (oci.SpecModifier, error) {
dtkCDIHookPath := cfg.DTKCTKConfig.Path
comDiscoverer, err := discover.NewCommonHCUDiscoverer(
logger,
dtkCDIHookPath,
driver,
isMount,
)
if err != nil {
return nil, fmt.Errorf("failed to create mounts discoverer: %v", err)
}
visibleDevices := containerImage.DevicesFromEnvvars(image.EnvVarDTKVisibleDevices, image.EnvVarNvidiaVisibleDevices)
if len(visibleDevices.List()) == 0 {
logger.Info("No devices requested")
return nil, nil
}
busIds, err := getDevicesFromDriver()
if err != nil {
logger.Errorf("No hcu found")
return nil, err
}
err = checkRequestDevices(logger, visibleDevices, busIds)
if err != nil {
return nil, err
}
var selectedBusIds []string
for i, busId := range busIds {
if visibleDevices.Has(fmt.Sprintf("%d", i)) || visibleDevices.Has(busId) {
selectedBusIds = append(selectedBusIds, busId)
}
}
// In standard usage, the devRoot is the same as the driver.Root.
devRoot := driver.Root
drmNodes, err := discover.NewDRMNodesDiscoverer(
logger,
busIds,
selectedBusIds,
devRoot,
)
if err != nil {
return nil, fmt.Errorf("failed to construct discoverer: %v", err)
}
drmByPathLinks := discover.NewCreateDRMByPathSymlinks(logger, drmNodes, devRoot, dtkCDIHookPath)
pciMounts := discover.NewPciMounts(
logger,
lookup.NewDirectoryLocator(
lookup.WithLogger(logger),
lookup.WithCount(1),
lookup.WithSearchPaths("/sys/bus/pci/devices"),
),
driver.Root,
selectedBusIds,
)
d := discover.Merge(
comDiscoverer,
drmNodes,
drmByPathLinks,
pciMounts,
)
return NewModifierFromDiscoverer(logger, d)
}
// getDevicesFromDriver query all HCU devices bus id
func getDevicesFromDriver() ([]string, error) {
var devices []string
matches, err := filepath.Glob("/sys/module/hy*cu/drivers/pci:hy*cu/[0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F]:*")
if err != nil {
return devices, fmt.Errorf("failed to find devices bus id: %v", err)
}
if len(matches) == 0 {
m, err := filepath.Glob("/sys/module/hy*cu/drivers/pci:amdgpu/[0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F]:*")
if err != nil {
return devices, fmt.Errorf("failed to find devices bus id: %v", err)
}
matches = append(matches, m...)
}
for _, path := range sort.StringSlice(matches) {
devices = append(devices, filepath.Base(path))
}
return devices, nil
}
func checkRequestDevices(logger logger.Interface, devices image.VisibleDevices, busIds []string) error {
for _, device := range devices.List() {
if device == "all" || device == "" {
break
}
deviceId, err := strconv.Atoi(device)
if err != nil {
found := false
for _, busId := range busIds {
if device == busId {
found = true
break
}
}
if !found {
return fmt.Errorf("request device %s not found", device)
}
} else if deviceId >= len(busIds) {
logger.Errorf("Request device %s is invalid", device)
return fmt.Errorf("request device %s not found", device)
}
}
return nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
"path/filepath"
"github.com/opencontainers/runtime-spec/specs-go"
)
func NewNvidiaContainerRuntimeHookRemover(logger logger.Interface) oci.SpecModifier {
m := nvidiaContainerRuntimeHookRemover{
logger: logger,
}
return &m
}
// nvidiaContainerRuntimeHookRemover is a spec modifier that detects and removes inserted nvidia-container-runtime hooks
type nvidiaContainerRuntimeHookRemover struct {
logger logger.Interface
}
var _ oci.SpecModifier = (*nvidiaContainerRuntimeHookRemover)(nil)
// Modify removes any NVIDIA Container Runtime hooks from the provided spec
func (m nvidiaContainerRuntimeHookRemover) Modify(spec *specs.Spec) error {
if spec == nil || spec.Hooks == nil {
return nil
}
if len(spec.Hooks.Prestart) == 0 {
return nil
}
var hooks []specs.Hook
for _, hook := range spec.Hooks.Prestart {
hook := hook
if isNVIDIAContainerRuntimeHook(&hook) {
m.logger.Debugf("Removing hook %v", hook)
continue
}
hooks = append(hooks, hook)
}
if len(hooks) != len(spec.Hooks.Prestart) {
m.logger.Debugf("Updating 'prestart' hooks to %v", hooks)
spec.Hooks.Prestart = hooks
}
return nil
}
// isNVIDIAContainerRuntimeHook checks if the provided hook is an nvidia-container-runtime-hook
// or nvidia-container-toolkit hook. These are included, for example, by the non-experimental
// nvidia-container-runtime or docker when specifying the --gpus flag.
func isNVIDIAContainerRuntimeHook(hook *specs.Hook) bool {
bins := map[string]struct{}{
"nvidia-container-runtime-hook": {},
"nvidia-container-toolkit": {},
}
_, exists := bins[filepath.Base(hook.Path)]
return exists
}
func NewSeccompRemover(logger logger.Interface) oci.SpecModifier {
m := seccompRemover{
logger: logger,
}
return &m
}
// seccompRemover is a spec modifer that disable seccomp
type seccompRemover struct {
logger logger.Interface
}
var _ oci.SpecModifier = (*seccompRemover)(nil)
func (m seccompRemover) Modify(spec *specs.Spec) error {
if spec == nil || spec.Linux == nil {
return nil
}
spec.Linux.Seccomp = nil
spec.Process.ApparmorProfile = ""
m.logger.Info("Remove linux.seccomp in OCI spec and Process.ApparmorProfile")
return nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/oci"
"github.com/opencontainers/runtime-spec/specs-go"
)
type List []oci.SpecModifier
// Merge merges a set of OCI specification modifiers as a list.
// This can be used to compose modifiers.
func Merge(modifiers ...oci.SpecModifier) oci.SpecModifier {
var filteredModifiers List
for _, m := range modifiers {
if m == nil {
continue
}
filteredModifiers = append(filteredModifiers, m)
}
return filteredModifiers
}
// Modify applies a list of modifiers in sequence and returns on any errors encountered.
func (m List) Modify(spec *specs.Spec) error {
for _, mm := range m {
if mm == nil {
continue
}
err := mm.Modify(spec)
if err != nil {
return err
}
}
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment