Commit d7e13eb9 authored by songlinfeng's avatar songlinfeng
Browse files

add dtk-container-toolkit

parent fcdba4f3
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"dtk-container-toolkit/internal/config"
"dtk-container-toolkit/internal/config/image"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
)
// NewStableModifier creates the modifiers for general features.
// These includes:
//
// removes inserted nvidia-container-runtime hooks for --gpus
func NewStableModifier(logger logger.Interface, cfg *config.Config, image image.DTK) (oci.SpecModifier, error) {
var modifiers List
modifiers = append(modifiers, NewNvidiaContainerRuntimeHookRemover(logger))
modifiers = append(modifiers, NewSeccompRemover(logger))
var addCaps []string
// For xprof
addCaps = append(addCaps, "CAP_SYS_RAWIO")
if image.Getenv("DTK_MOFED") == "enabled" {
addCaps = append(addCaps, "CAP_IPC_LOCK")
}
modifiers = append(modifiers, NewCapModifier(logger, addCaps, []string{}))
return modifiers, nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package modifier
import (
"bufio"
"dtk-container-toolkit/internal/config/image"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/oci"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"github.com/opencontainers/runtime-spec/specs-go"
)
var reDrmRenderMinor = regexp.MustCompile(`drm_render_minor\s(\d+)`)
// sysfsMountModifier is a spec modifier that handle subdirectory mount in /sys
type sysfsMountModifier struct {
logger logger.Interface
devices image.VisibleDevices
busIds []string
dtkCDIHookPath string
}
var _ oci.SpecModifier = (*sysfsMountModifier)(nil)
func NewSysfsMountModifier(logger logger.Interface, devices image.VisibleDevices, busIds []string, dtkCDIHookPath string) oci.SpecModifier {
m := sysfsMountModifier{
logger: logger,
devices: devices,
busIds: busIds,
dtkCDIHookPath: dtkCDIHookPath,
}
return &m
}
func (m sysfsMountModifier) Modify(spec *specs.Spec) error {
if spec == nil {
return nil
}
var selectedBusIds []string
isAll := true
for i, busId := range m.busIds {
if m.devices.Has(fmt.Sprintf("%d", i)) || m.devices.Has(busId) {
selectedBusIds = append(selectedBusIds, busId)
} else {
isAll = false
}
}
if isAll {
m.logger.Debugf("All devices requested, no need to handle /sys mount")
return nil
}
var mounts []specs.Mount
mounted := make(map[string]bool)
for _, mount := range spec.Mounts {
mount := mount
if mount.Destination == "/sys" {
continue
}
mounts = append(mounts, mount)
if strings.HasPrefix(mount.Source, "/sys") {
mounted[mount.Source] = true
}
}
selectRender := make(map[string]bool)
for _, busId := range selectedBusIds {
drmRoot := filepath.Join("/sys/bus/pci/devices", busId, "drm")
renderNodes, err := filepath.Glob(fmt.Sprintf("%s/renderD*", drmRoot))
if err != nil {
return fmt.Errorf("failed to determine DRM render devices for %v: %v", busId, err)
}
for _, renderNode := range renderNodes {
selectRender[filepath.Base(renderNode)] = true
}
}
nodeRoot := "/sys/devices/virtual/kfd/kfd/topology/nodes"
matches, err := filepath.Glob(fmt.Sprintf("%s/*", nodeRoot))
if err != nil {
m.logger.Warningf("Failed to found topology nodes")
return err
}
for _, path := range matches {
render_minor, err := ParseTopologyProperties(filepath.Join(path, "properties"), reDrmRenderMinor)
if err != nil {
return err
}
if int(render_minor) == 0 || selectRender[fmt.Sprintf("renderD%d", int(render_minor))] {
mounts = append(mounts, specs.Mount{
Destination: path,
Type: "bind",
Source: path,
Options: []string{"rbind", "rprivate"},
})
}
}
var links []string
curPath := filepath.Dir(nodeRoot)
base := filepath.Base(nodeRoot)
for {
matches, err := filepath.Glob(fmt.Sprintf("%s/*", curPath))
if err != nil {
m.logger.Warningf("failed to find subdirecties for %s: %v", curPath, err)
return nil
}
for _, path := range matches {
if filepath.Base(path) == base || mounted[path] {
continue
}
lpath, err := os.Readlink(path)
if err != nil {
mounts = append(mounts, specs.Mount{
Destination: path,
Type: "bind",
Source: path,
Options: []string{"rbind", "rprivate"},
})
} else {
m.logger.Debugf("adding symlink %v -> %v", path, lpath)
links = append(links, fmt.Sprintf("%v::%v", lpath, path))
}
}
base = filepath.Base(curPath)
curPath = filepath.Dir(curPath)
if curPath == "/" {
break
}
}
spec.Mounts = mounts
if len(links) != 0 {
var args []string
args = append(args, "dtk-ctk", "hook", "create-symlinks")
for _, l := range links {
args = append(args, "--link", l)
}
var hooks []specs.Hook
for _, hook := range spec.Hooks.CreateContainer {
hook := hook
hooks = append(hooks, hook)
}
hooks = append(hooks, specs.Hook{
Path: m.dtkCDIHookPath,
Args: args,
})
spec.Hooks.CreateContainer = hooks
}
return nil
}
// ParseTopologyProperties parse for a property value in kfd topology file
// The format is usually one entry per line <name> <value>.
func ParseTopologyProperties(path string, re *regexp.Regexp) (int64, error) {
f, e := os.Open(path)
if e != nil {
return 0, e
}
e = errors.New("Topology property not found. Regex: " + re.String())
v := int64(0)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
m := re.FindStringSubmatch(scanner.Text())
if m == nil {
continue
}
v, e = strconv.ParseInt(m[1], 0, 64)
break
}
f.Close()
return v, e
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"fmt"
"path/filepath"
"strings"
)
const (
specFileName = "config.json"
)
// GetBundleDir returns the bundle directory or default depending on the
// supplied command line arguments.
func GetBundleDir(args []string) (string, error) {
bundleDir, err := GetBundleDirFromArgs(args)
if err != nil {
return "", fmt.Errorf("error getting bundle dir from args: %v", err)
}
return bundleDir, nil
}
// GetBundleDirFromArgs checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
// The following are supported:
// --bundle{{SEP}}BUNDLE_PATH
// -bundle{{SEP}}BUNDLE_PATH
// -b{{SEP}}BUNDLE_PATH
// where {{SEP}} is either ' ' or '='
func GetBundleDirFromArgs(args []string) (string, error) {
var bundleDir string
for i := 0; i < len(args); i++ {
param := args[i]
parts := strings.SplitN(param, "=", 2)
if !IsBundleFlag(parts[0]) {
continue
}
// The flag has the format --bundle=/path
if len(parts) == 2 {
bundleDir = parts[1]
continue
}
// The flag has the format --bundle /path
if i+1 < len(args) {
bundleDir = args[i+1]
i++
continue
}
// --bundle / -b was the last element of args
return "", fmt.Errorf("bundle option requires an argument")
}
return bundleDir, nil
}
// GetSpecFilePath returns the expected path to the OCI specification file for the given
// bundle directory.
func GetSpecFilePath(bundleDir string) string {
specFilePath := filepath.Join(bundleDir, specFileName)
return specFilePath
}
// IsBundleFlag is a helper function that checks wither the specified argument represents
// a bundle flag (--bundle or -b)
func IsBundleFlag(arg string) bool {
if !strings.HasPrefix(arg, "-") {
return false
}
trimmed := strings.TrimLeft(arg, "-")
return trimmed == "b" || trimmed == "bundle"
}
// HasCreateSubcommand checks the supplied arguments for a 'create' subcommand
func HasCreateSubcommand(args []string) bool {
var previousWasBundle bool
for _, a := range args {
// We check for '--bundle create' explicitly to ensure that we
// don't inadvertently trigger a modification if the bundle directory
// is specified as `create
if !previousWasBundle && IsBundleFlag(a) {
previousWasBundle = true
continue
}
if !previousWasBundle && a == "create" {
return true
}
previousWasBundle = false
}
return false
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
//go:generate moq -stub -out runtime_mock.go . Runtime
// Runtime is an interface for a runtime shim. The Exec method accepts a list
// of command line arguments, and returns an error / nil.
type Runtime interface {
Exec([]string) error
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/internal/lookup"
"fmt"
)
// NewLowLevelRuntime creates a Runtime that wraps a low-level runtime executable.
// The executable specified is taken from the list of supplied candidates, with the first match
// present in the PATH being selected. A logger is also specified.
func NewLowLevelRuntime(logger logger.Interface, candidates []string) (Runtime, error) {
runtimePath, err := findRuntime(logger, candidates)
if err != nil {
return nil, fmt.Errorf("error locating runtime: %v", err)
}
logger.Infof("using low-level runtime %v", runtimePath)
return NewRuntimeForPath(logger, runtimePath)
}
// findRuntime checks elements in a list of supplied candidates for a matching executable in the PATH.
// The absolute path to the first match is returned.
func findRuntime(logger logger.Interface, candidates []string) (string, error) {
if len(candidates) == 0 {
return "", fmt.Errorf("at least one runtime candidate must be specified")
}
locator := lookup.NewExecutableLocator(logger, "/")
for _, candidate := range candidates {
logger.Debugf("Looking for runtime library '%v'", candidate)
targets, err := locator.Locate(candidate)
if err == nil && len(targets) > 0 {
logger.Debugf("Found runtime binary '%v'", targets)
return targets[0], nil
}
logger.Debugf("Runtime binary '%v' not found: %v (targets=%v)", candidate, err, targets)
}
return "", fmt.Errorf("no runtime binary found from candidate list: %v", candidates)
}
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package oci
import (
"sync"
)
// Ensure, that RuntimeMock does implement Runtime.
// If this is not the case, regenerate this file with moq.
var _ Runtime = &RuntimeMock{}
// RuntimeMock is a mock implementation of Runtime.
//
// func TestSomethingThatUsesRuntime(t *testing.T) {
//
// // make and configure a mocked Runtime
// mockedRuntime := &RuntimeMock{
// ExecFunc: func(strings []string) error {
// panic("mock out the Exec method")
// },
// }
//
// // use mockedRuntime in code that requires Runtime
// // and then make assertions.
//
// }
type RuntimeMock struct {
// ExecFunc mocks the Exec method.
ExecFunc func(strings []string) error
// calls tracks calls to the methods.
calls struct {
// Exec holds details about calls to the Exec method.
Exec []struct {
// Strings is the strings argument value.
Strings []string
}
}
lockExec sync.RWMutex
}
// Exec calls ExecFunc.
func (mock *RuntimeMock) Exec(strings []string) error {
callInfo := struct {
Strings []string
}{
Strings: strings,
}
mock.lockExec.Lock()
mock.calls.Exec = append(mock.calls.Exec, callInfo)
mock.lockExec.Unlock()
if mock.ExecFunc == nil {
var (
errOut error
)
return errOut
}
return mock.ExecFunc(strings)
}
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
//
// len(mockedRuntime.ExecCalls())
func (mock *RuntimeMock) ExecCalls() []struct {
Strings []string
} {
var calls []struct {
Strings []string
}
mock.lockExec.RLock()
calls = mock.calls.Exec
mock.lockExec.RUnlock()
return calls
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"dtk-container-toolkit/internal/logger"
"fmt"
)
type modifyingRuntimeWrapper struct {
logger logger.Interface
runtime Runtime
ociSpec Spec
modifier SpecModifier
}
var _ Runtime = (*modifyingRuntimeWrapper)(nil)
// NewModifyingRuntimeWrapper creates a runtime wrapper that applies the specified modifier to the OCI specification
// before invoking the wrapped runtime. If the modifier is nil, the input runtime is returned.
func NewModifyingRuntimeWrapper(logger logger.Interface, runtime Runtime, spec Spec, modifier SpecModifier) Runtime {
if modifier == nil {
logger.Infof("Using low-level runtime with no modification")
return runtime
}
rt := modifyingRuntimeWrapper{
logger: logger,
runtime: runtime,
ociSpec: spec,
modifier: modifier,
}
return &rt
}
// Exec checks whether a modification of the OCI specification is required and modifies it accordingly before exec-ing
// into the wrapped runtime.
func (r *modifyingRuntimeWrapper) Exec(args []string) error {
if HasCreateSubcommand(args) {
err := r.modify()
if err != nil {
return fmt.Errorf("could not apply required modification to OCI specification: %v", err)
}
r.logger.Infof("Applied required modification to OCI specification")
} else {
r.logger.Infof("No modification of OCI specification required")
}
r.logger.Infof("Forwarding command to runtime")
return r.runtime.Exec(args)
}
// modify loads, modifies, and flushes the OCI specification using the defined Modifier
func (r *modifyingRuntimeWrapper) modify() error {
_, err := r.ociSpec.Load()
if err != nil {
return fmt.Errorf("error loading OCI specification for modification: %v", err)
}
err = r.ociSpec.Modify(r.modifier)
if err != nil {
return fmt.Errorf("error modifying OCI spec: %v", err)
}
err = r.ociSpec.Flush()
if err != nil {
return fmt.Errorf("error writing modified OCI specification: %v", err)
}
return nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"dtk-container-toolkit/internal/logger"
"fmt"
"os"
)
// pathRuntime wraps the path that a binary and defines the semantics for how to exec into it.
// This can be used to wrap an OCI-compliant low-level runtime binary, allowing it to be used through the
// Runtime internface.
type pathRuntime struct {
logger logger.Interface
path string
execRuntime Runtime
}
var _ Runtime = (*pathRuntime)(nil)
// NewRuntimeForPath creates a Runtime for the specified logger and path
func NewRuntimeForPath(logger logger.Interface, path string) (Runtime, error) {
info, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("invalid path '%v': %v", path, err)
}
if info.IsDir() || info.Mode()&0111 == 0 {
return nil, fmt.Errorf("specified path '%v' is not an executable file", path)
}
shim := pathRuntime{
logger: logger,
path: path,
execRuntime: syscallExec{},
}
return &shim, nil
}
// Exec exces into the binary at the path from the pathRuntime struct, passing it the supplied arguments
// after ensuring that the first argument is the path of the target binary.
func (s pathRuntime) Exec(args []string) error {
runtimeArgs := []string{s.path}
if len(args) > 1 {
runtimeArgs = append(runtimeArgs, args[1:]...)
}
return s.execRuntime.Exec(runtimeArgs)
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"fmt"
"os"
"syscall"
)
type syscallExec struct{}
var _ Runtime = (*syscallExec)(nil)
func (r syscallExec) Exec(args []string) error {
//nolint:gosec // TODO: Can we harden this so that there is less risk of command injection
err := syscall.Exec(args[0], args, os.Environ())
if err != nil {
return fmt.Errorf("cound not exec '%v': %v", args[0], err)
}
// syscall.Exec is not expected to return. This is an error state regardless of whether
// err is nil or not.
return fmt.Errorf("unexpected return from exec '%v'", args[0])
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"dtk-container-toolkit/internal/logger"
"fmt"
"github.com/opencontainers/runtime-spec/specs-go"
)
// SpecModifier defines an interface for modifying a (raw) OCI spec
type SpecModifier interface {
// Modify is a method that accepts a pointer to an OCI Srec and returns an
// error. The intention is that the function would modify the spec in-place.
Modify(*specs.Spec) error
}
// Spec defines the operations to be performed on an OCI specification
//
//go:generate moq -stub -out spec_mock.go . Spec
type Spec interface {
Load() (*specs.Spec, error)
Flush() error
Modify(SpecModifier) error
LookupEnv(string) (string, bool)
}
// NewSpec creates fileSpec based on the command line arguments passed to the
// application using the specified logger.
func NewSpec(logger logger.Interface, args []string) (Spec, error) {
bundleDir, err := GetBundleDir(args)
if err != nil {
return nil, fmt.Errorf("error getting bundle directory: %v", err)
}
logger.Debugf("Using bundle directory: %v", bundleDir)
ociSpecPath := GetSpecFilePath(bundleDir)
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
ociSpec := NewFileSpec(ociSpecPath)
return ociSpec, nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"encoding/json"
"fmt"
"io"
"os"
"github.com/opencontainers/runtime-spec/specs-go"
)
type fileSpec struct {
memorySpec
path string
}
var _ Spec = (*fileSpec)(nil)
// NewFileSpec creates an object that encapsulates a file-backed OCI spec.
// This can be used to read from the file, modify the spec, and write to the
// same file.
func NewFileSpec(filepath string) Spec {
oci := fileSpec{
path: filepath,
}
return &oci
}
// Load reads the contents of an OCI spec from file to be referenced internally.
// The file is opened "read-only"
func (s *fileSpec) Load() (*specs.Spec, error) {
specFile, err := os.Open(s.path)
if err != nil {
return nil, fmt.Errorf("error opening OCI specification file: %v", err)
}
defer specFile.Close()
spec, err := LoadFrom(specFile)
if err != nil {
return nil, fmt.Errorf("error loading OCI specification from file: %v", err)
}
s.Spec = spec
return s.Spec, nil
}
// LoadFrom reads the contents of the OCI spec from the specified io.Reader.
func LoadFrom(reader io.Reader) (*specs.Spec, error) {
decoder := json.NewDecoder(reader)
var spec specs.Spec
err := decoder.Decode(&spec)
if err != nil {
return nil, fmt.Errorf("error reading OCI specification: %v", err)
}
return &spec, nil
}
// Modify applies the specified SpecModifier to the stored OCI specification.
func (s *fileSpec) Modify(m SpecModifier) error {
return s.memorySpec.Modify(m)
}
// Flush writes the stored OCI specification to the filepath specified by the path member.
// The file is truncated upon opening, overwriting any existing contents.
func (s fileSpec) Flush() error {
if s.Spec == nil {
return fmt.Errorf("no OCI specification loaded")
}
specFile, err := os.Create(s.path)
if err != nil {
return fmt.Errorf("error opening OCI specification file: %v", err)
}
defer specFile.Close()
return flushTo(s.Spec, specFile)
}
// flushTo writes the stored OCI specification to the specified io.Writer.
func flushTo(spec *specs.Spec, writer io.Writer) error {
if spec == nil {
return nil
}
encoder := json.NewEncoder(writer)
err := encoder.Encode(spec)
if err != nil {
return fmt.Errorf("error writing OCI specification: %v", err)
}
return nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"fmt"
"strings"
"github.com/opencontainers/runtime-spec/specs-go"
)
type memorySpec struct {
*specs.Spec
}
// NewMemorySpec creates a Spec instance from the specified OCI spec
func NewMemorySpec(spec *specs.Spec) Spec {
s := memorySpec{
Spec: spec,
}
return &s
}
// Load is a no-op for the memorySpec spec
func (s *memorySpec) Load() (*specs.Spec, error) {
return s.Spec, nil
}
// Flush is a no-op for the memorySpec spec
func (s *memorySpec) Flush() error {
return nil
}
// Modify applies the specified SpecModifier to the stored OCI specification.
func (s *memorySpec) Modify(m SpecModifier) error {
if s.Spec == nil {
return fmt.Errorf("cannot modify nil spec")
}
return m.Modify(s.Spec)
}
// LookupEnv mirrors os.LookupEnv for the OCI specification. It
// retrieves the value of the environment variable named
// by the key. If the variable is present in the environment the
// value (which may be empty) is returned and the boolean is true.
// Otherwise the returned value will be empty and the boolean will
// be false.
func (s memorySpec) LookupEnv(key string) (string, bool) {
if s.Spec == nil || s.Spec.Process == nil {
return "", false
}
for _, env := range s.Spec.Process.Env {
if !strings.HasPrefix(env, key) {
continue
}
parts := strings.SplitN(env, "=", 2)
if parts[0] == key {
if len(parts) < 2 {
return "", true
}
return parts[1], true
}
}
return "", false
}
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package oci
import (
"github.com/opencontainers/runtime-spec/specs-go"
"sync"
)
// Ensure, that SpecMock does implement Spec.
// If this is not the case, regenerate this file with moq.
var _ Spec = &SpecMock{}
// SpecMock is a mock implementation of Spec.
//
// func TestSomethingThatUsesSpec(t *testing.T) {
//
// // make and configure a mocked Spec
// mockedSpec := &SpecMock{
// FlushFunc: func() error {
// panic("mock out the Flush method")
// },
// LoadFunc: func() (*specs.Spec, error) {
// panic("mock out the Load method")
// },
// LookupEnvFunc: func(s string) (string, bool) {
// panic("mock out the LookupEnv method")
// },
// ModifyFunc: func(specModifier SpecModifier) error {
// panic("mock out the Modify method")
// },
// }
//
// // use mockedSpec in code that requires Spec
// // and then make assertions.
//
// }
type SpecMock struct {
// FlushFunc mocks the Flush method.
FlushFunc func() error
// LoadFunc mocks the Load method.
LoadFunc func() (*specs.Spec, error)
// LookupEnvFunc mocks the LookupEnv method.
LookupEnvFunc func(s string) (string, bool)
// ModifyFunc mocks the Modify method.
ModifyFunc func(specModifier SpecModifier) error
// calls tracks calls to the methods.
calls struct {
// Flush holds details about calls to the Flush method.
Flush []struct {
}
// Load holds details about calls to the Load method.
Load []struct {
}
// LookupEnv holds details about calls to the LookupEnv method.
LookupEnv []struct {
// S is the s argument value.
S string
}
// Modify holds details about calls to the Modify method.
Modify []struct {
// SpecModifier is the specModifier argument value.
SpecModifier SpecModifier
}
}
lockFlush sync.RWMutex
lockLoad sync.RWMutex
lockLookupEnv sync.RWMutex
lockModify sync.RWMutex
}
// Flush calls FlushFunc.
func (mock *SpecMock) Flush() error {
callInfo := struct {
}{}
mock.lockFlush.Lock()
mock.calls.Flush = append(mock.calls.Flush, callInfo)
mock.lockFlush.Unlock()
if mock.FlushFunc == nil {
var (
errOut error
)
return errOut
}
return mock.FlushFunc()
}
// FlushCalls gets all the calls that were made to Flush.
// Check the length with:
//
// len(mockedSpec.FlushCalls())
func (mock *SpecMock) FlushCalls() []struct {
} {
var calls []struct {
}
mock.lockFlush.RLock()
calls = mock.calls.Flush
mock.lockFlush.RUnlock()
return calls
}
// Load calls LoadFunc.
func (mock *SpecMock) Load() (*specs.Spec, error) {
callInfo := struct {
}{}
mock.lockLoad.Lock()
mock.calls.Load = append(mock.calls.Load, callInfo)
mock.lockLoad.Unlock()
if mock.LoadFunc == nil {
var (
specOut *specs.Spec
errOut error
)
return specOut, errOut
}
return mock.LoadFunc()
}
// LoadCalls gets all the calls that were made to Load.
// Check the length with:
//
// len(mockedSpec.LoadCalls())
func (mock *SpecMock) LoadCalls() []struct {
} {
var calls []struct {
}
mock.lockLoad.RLock()
calls = mock.calls.Load
mock.lockLoad.RUnlock()
return calls
}
// LookupEnv calls LookupEnvFunc.
func (mock *SpecMock) LookupEnv(s string) (string, bool) {
callInfo := struct {
S string
}{
S: s,
}
mock.lockLookupEnv.Lock()
mock.calls.LookupEnv = append(mock.calls.LookupEnv, callInfo)
mock.lockLookupEnv.Unlock()
if mock.LookupEnvFunc == nil {
var (
sOut string
bOut bool
)
return sOut, bOut
}
return mock.LookupEnvFunc(s)
}
// LookupEnvCalls gets all the calls that were made to LookupEnv.
// Check the length with:
//
// len(mockedSpec.LookupEnvCalls())
func (mock *SpecMock) LookupEnvCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLookupEnv.RLock()
calls = mock.calls.LookupEnv
mock.lockLookupEnv.RUnlock()
return calls
}
// Modify calls ModifyFunc.
func (mock *SpecMock) Modify(specModifier SpecModifier) error {
callInfo := struct {
SpecModifier SpecModifier
}{
SpecModifier: specModifier,
}
mock.lockModify.Lock()
mock.calls.Modify = append(mock.calls.Modify, callInfo)
mock.lockModify.Unlock()
if mock.ModifyFunc == nil {
var (
errOut error
)
return errOut
}
return mock.ModifyFunc(specModifier)
}
// ModifyCalls gets all the calls that were made to Modify.
// Check the length with:
//
// len(mockedSpec.ModifyCalls())
func (mock *SpecMock) ModifyCalls() []struct {
SpecModifier SpecModifier
} {
var calls []struct {
SpecModifier SpecModifier
}
mock.lockModify.RLock()
calls = mock.calls.Modify
mock.lockModify.RUnlock()
return calls
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package oci
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"github.com/opencontainers/runtime-spec/specs-go"
)
// State stores an OCI container state. This includes the spec path and the environment
type State specs.State
// LoadContainerState loads the container state from the specified filename. If the filename is empty or '-' the state is loaded from STDIN
func LoadContainerState(filename string) (*State, error) {
if filename == "" || filename == "-" {
return ReadContainerState(os.Stdin)
}
inputFile, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open file: %v", err)
}
defer inputFile.Close()
return ReadContainerState(inputFile)
}
// ReadContainerState reads the container state from the specified reader
func ReadContainerState(reader io.Reader) (*State, error) {
var s State
d := json.NewDecoder(reader)
if err := d.Decode(&s); err != nil {
return nil, fmt.Errorf("failed to decode container state: %v", err)
}
return &s, nil
}
// LoadSpec loads the OCI spec associated with the container state
func (s *State) LoadSpec() (*specs.Spec, error) {
specFilePath := GetSpecFilePath(s.Bundle)
specFile, err := os.Open(specFilePath)
if err != nil {
return nil, fmt.Errorf("failed to open OCI spec file: %v", err)
}
defer specFile.Close()
spec, err := LoadFrom(specFile)
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
return spec, nil
}
// GetContainerRoot returns the root for the container from the associated spec. If the spec is not yet loaded, it is
// loaded and cached.
func (s *State) GetContainerRoot() (string, error) {
spec, err := s.LoadSpec()
if err != nil {
return "", err
}
var containerRoot string
if spec.Root != nil {
containerRoot = spec.Root.Path
}
if filepath.IsAbs(containerRoot) {
return containerRoot, nil
}
return filepath.Join(s.Bundle, containerRoot), nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package dhcu
import (
"dtk-container-toolkit/internal/discover"
"dtk-container-toolkit/internal/info/drm"
"dtk-container-toolkit/internal/lookup"
"dtk-container-toolkit/pkg/go-c3000lib/pkg/device"
"fmt"
)
func (o *options) newC3000SimDHCUDiscoverer(d device.Device) (discover.Discover, error) {
pciBusID := d.GetPCIBusID()
drmDeviceNodes, err := drm.GetDeviceNodesByBusID(pciBusID)
if err != nil {
return nil, fmt.Errorf("failed to determine DRM devices for %v: %v", pciBusID, err)
}
deviceNodes := discover.NewCharDeviceDiscoverer(
o.logger,
o.devRoot,
drmDeviceNodes,
)
byPathHooks := discover.NewCreateDRMByPathSymlinks(o.logger, deviceNodes, o.devRoot, o.dtkCDIHookPath)
pciMounts := discover.NewPciMounts(
o.logger,
lookup.NewDirectoryLocator(
lookup.WithLogger(o.logger),
lookup.WithCount(1),
lookup.WithSearchPaths("/sys/bus/pci/devices"),
),
o.devRoot,
[]string{pciBusID},
)
dd := discover.Merge(
deviceNodes,
byPathHooks,
pciMounts,
)
return dd, nil
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package dhcu
import (
"dtk-container-toolkit/internal/discover"
"dtk-container-toolkit/internal/logger"
"dtk-container-toolkit/pkg/go-c3000lib/pkg/device"
"errors"
)
// NewForDevice creates a discoverer for the specified Device.
func NewForDevice(d device.Device, opts ...Option) (discover.Discover, error) {
o := new(opts...)
var discoverers []discover.Discover
var errs error
c3000smiDiscoverer, err := o.newC3000SimDHCUDiscoverer(d)
if err != nil {
errs = errors.Join(errs, err)
} else if c3000smiDiscoverer != nil {
discoverers = append(discoverers, c3000smiDiscoverer)
}
if len(discoverers) == 0 {
return nil, errs
}
return discover.WithCache(
discover.FirstValid(
discoverers...,
),
), nil
}
func new(opts ...Option) *options {
o := &options{}
for _, opt := range opts {
opt(o)
}
if o.logger == nil {
o.logger = logger.New()
}
return o
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package dhcu
import (
"dtk-container-toolkit/internal/c3000caps"
"dtk-container-toolkit/internal/logger"
)
type options struct {
logger logger.Interface
devRoot string
dtkCDIHookPath string
isMigDevice bool
// migCaps stores the MIG capabilities for the system.
// If MIG is not available, this is nil.
migCaps c3000caps.MigCaps
migCapsError error
}
type Option func(*options)
// WithDevRoot sets the root where /dev is located.
func WithDevRoot(root string) Option {
return func(l *options) {
l.devRoot = root
}
}
// WithLogger sets the logger for the library
func WithLogger(logger logger.Interface) Option {
return func(l *options) {
l.logger = logger
}
}
// WithDTKCDIHookPath sets the path to the DTK Container Toolkit CLI path for the library
func WithDTKCDIHookPath(path string) Option {
return func(l *options) {
l.dtkCDIHookPath = path
}
}
// WithMIGCaps sets the MIG capabilities.
func WithMIGCaps(migCaps c3000caps.MigCaps) Option {
return func(l *options) {
l.migCaps = migCaps
}
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package runtime
type rt struct {
logger *Logger
modeOverride string
}
// Interface is the interface for the runtime library.
type Interface interface {
Run([]string) error
}
// Option is a function that configures the runtime.
type Option func(*rt)
// New creates a runtime with the specified options.
func New(opts ...Option) Interface {
r := rt{}
for _, opt := range opts {
opt(&r)
}
if r.logger == nil {
r.logger = NewLogger()
}
return &r
}
// WithModeOverride allows for overriding the mode specified in the config.
func WithModeOverride(mode string) Option {
return func(r *rt) {
r.modeOverride = mode
}
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package runtime
import (
"dtk-container-toolkit/internal/logger"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"github.com/sirupsen/logrus"
)
// Logger adds a way to manage output to a log file to a logrus.Logger
type Logger struct {
logger.Interface
previousLogger logger.Interface
logFiles []*os.File
}
// NewLogger creates an empty logger
func NewLogger() *Logger {
return &Logger{
Interface: logrus.New(),
}
}
// Update constructs a Logger with a preddefined formatter
func (l *Logger) Update(filename string, logLevel string, argv []string) {
configFromArgs := parseArgs(argv)
level, logLevelError := configFromArgs.getLevel(logLevel)
defer func() {
if logLevelError != nil {
l.Warning(logLevelError)
}
}()
var logFiles []*os.File
var argLogFileError error
// We don't create log files if the version argument is supplied
if !configFromArgs.version {
configLogFile, err := createLogFile(filename)
if err != nil {
argLogFileError = errors.Join(argLogFileError, err)
}
if configLogFile != nil {
logFiles = append(logFiles, configLogFile)
}
argLogFile, err := createLogFile(configFromArgs.file)
if argLogFile != nil {
logFiles = append(logFiles, argLogFile)
}
argLogFileError = errors.Join(argLogFileError, err)
}
defer func() {
if argLogFileError != nil {
l.Warningf("Failed to open log file :%v", argLogFileError)
}
}()
newLogger := logrus.New()
newLogger.SetLevel(level)
if level == logrus.DebugLevel {
logrus.SetReportCaller(true)
// Shorten function and file names reported by the logger, by
// trimming common "github.com/opencontainers/runc" prefix.
// This is only done for text formatter.
_, file, _, _ := runtime.Caller(0)
prefix := filepath.Dir(file) + "/"
logrus.SetFormatter(&logrus.TextFormatter{
CallerPrettyfier: func(f *runtime.Frame) (string, string) {
function := strings.TrimPrefix(f.Function, prefix) + "()"
fileLine := strings.TrimPrefix(f.File, prefix) + ":" + strconv.Itoa(f.Line)
return function, fileLine
},
})
}
if configFromArgs.format == "json" {
newLogger.SetFormatter(new(logrus.JSONFormatter))
}
switch len(logFiles) {
case 0:
newLogger.SetOutput(io.Discard)
case 1:
newLogger.SetOutput(logFiles[0])
default:
var writers []io.Writer
for _, f := range logFiles {
writers = append(writers, f)
}
newLogger.SetOutput(io.MultiWriter(writers...))
}
*l = Logger{
Interface: newLogger,
previousLogger: l.Interface,
logFiles: logFiles,
}
}
// Reset closes the log file (if any) and resets the logger output to what it
// was before UpdateLogger was called.
func (l *Logger) Reset() error {
defer func() {
previous := l.previousLogger
if previous == nil {
previous = logger.New()
}
l.Interface = previous
l.previousLogger = nil
l.logFiles = nil
}()
var errs []error
for _, f := range l.logFiles {
err := f.Close()
if err != nil {
errs = append(errs, err)
}
}
var err error
for _, e := range errs {
if err == nil {
err = e
continue
}
return fmt.Errorf("%v: %w", e, err)
}
return err
}
func createLogFile(filename string) (*os.File, error) {
if filename == "" || filename == os.DevNull {
return nil, nil
}
if dir := filepath.Dir(filepath.Clean(filename)); dir != "." {
err := os.MkdirAll(dir, 0755)
if err != nil {
return nil, err
}
}
return os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
}
type loggerConfig struct {
file string
format string
debug bool
version bool
}
func (c loggerConfig) getLevel(logLevel string) (logrus.Level, error) {
if c.debug {
return logrus.DebugLevel, nil
}
if logLevel, err := logrus.ParseLevel(logLevel); err == nil {
return logLevel, nil
}
return logrus.InfoLevel, fmt.Errorf("invalid log-level '%v'", logLevel)
}
// Informed by Taken from https://github.com/opencontainers/runc/blob/7fd8b57001f5bfa102e89cb434d96bf71f7c1d35/main.go#L182
func parseArgs(args []string) loggerConfig {
c := loggerConfig{}
expected := map[string]*string{
"log-format": &c.format,
"log": &c.file,
}
found := make(map[string]bool)
for i := 0; i < len(args); i++ {
if len(found) == 4 {
break
}
param := args[i]
parts := strings.SplitN(param, "=", 2)
trimmed := strings.TrimLeft(parts[0], "-")
// If this is not a flag we continue
if parts[0] == trimmed {
continue
}
// Check the version flag
if trimmed == "version" {
c.version = true
found["version"] = true
// For the version flag we don't process any other flags
continue
}
// Check the debug flag
if trimmed == "debug" {
c.debug = true
found["debug"] = true
continue
}
destination, exists := expected[trimmed]
if !exists {
continue
}
var value string
switch {
case len(parts) == 2:
value = parts[2]
case i+1 < len(args):
value = args[i+1]
i++
default:
continue
}
*destination = value
found[trimmed] = true
}
return c
}
/**
# Copyright (c) 2024, HCUOpt CORPORATION. All rights reserved.
**/
package runtime
import (
"dtk-container-toolkit/internal/config"
"dtk-container-toolkit/internal/info"
"dtk-container-toolkit/internal/lookup/root"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/opencontainers/runtime-spec/specs-go"
)
// Run is an entry point that allows for idiomatic handling of errors
// when calling from the main function.
func (r rt) Run(argv []string) (rerr error) {
defer func() {
if rerr != nil {
r.logger.Errorf("%v", rerr)
}
}()
printVersion := hasVersionFlag(argv)
if printVersion {
fmt.Printf("%v version %v\n", "DTK Container Runtime", info.GetVersionString(fmt.Sprintf("spec: %v", specs.Version)))
}
cfg, err := config.GetConfig()
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
r.logger.Update(
cfg.DTKContainerRuntimeConfig.DebugFilePath,
cfg.DTKContainerRuntimeConfig.LogLevel,
argv,
)
defer func() {
if rerr != nil {
r.logger.Errorf("%v", rerr)
}
if err := r.logger.Reset(); err != nil {
rerr = errors.Join(rerr, fmt.Errorf("failed to reset logger: %v", err))
}
}()
//nolint:staticcheck // TODO(elezar): We should switch the dtk-container-runtime from using dtk-ctk to using dtk-cdi-hook.
cfg.DTKCTKConfig.Path = config.ResolveDTKCDIHookPath(r.logger, cfg.DTKCTKConfig.Path)
// Print the config to the output.
configJSON, err := json.MarshalIndent(cfg, "", " ")
if err == nil {
r.logger.Debugf("Running with config:\n%v", string(configJSON))
} else {
r.logger.Debugf("Running with config:\n%+v", cfg)
}
driver := root.New(
root.WithLogger(r.logger),
root.WithDriverRoot(""),
)
r.logger.Infof("Command line arguments: %v", argv)
runtime, err := newDTKContainerRuntime(r.logger, cfg, argv, driver)
if err != nil {
return fmt.Errorf("failed to create DTK Container Runtime: %v", err)
}
if printVersion {
fmt.Print("\n")
}
return runtime.Exec(argv)
}
func (r rt) Errorf(format string, args ...interface{}) {
r.logger.Errorf(format, args...)
}
// TODO: This should be refactored / combined with parseArgs in logger.
func hasVersionFlag(args []string) bool {
for i := 0; i < len(args); i++ {
param := args[i]
parts := strings.SplitN(param, "=", 2)
trimmed := strings.TrimLeft(parts[0], "-")
// If this is not a flag we continue
if parts[0] == trimmed {
continue
}
// Check the version flag
if trimmed == "version" {
return true
}
}
return false
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment