/**
# Copyright (c) 2024, HCUOpt CORPORATION.  All rights reserved.
**/

package discover

import (
	"dcu-container-toolkit/internal/config/image"
	"dcu-container-toolkit/internal/info/drm"
	"dcu-container-toolkit/internal/logger"
	"dcu-container-toolkit/internal/lookup"
	"dcu-container-toolkit/internal/lookup/root"
	"fmt"
	"os"
)

// NewDRMNodesDiscoverer returns a discoverer for the DRM device nodes associated with the specified visible devices.
//
// Only DRM nodes belonging to the bus IDs listed in requestBusIds (as resolved by
// drm.GetDeviceNodesByBusID) are selected. busIds is forwarded to
// newDRMDeviceDiscoverer, which currently does not use it. Errors from the
// underlying discoverer are wrapped with %w so callers can inspect them with
// errors.Is / errors.As.
//
// TODO: The logic for creating DRM devices should be consolidated between this
// and the logic for generating CDI specs for a single device. This is only used
// when applying OCI spec modifications to an incoming spec in "legacy" mode.
func NewDRMNodesDiscoverer(logger logger.Interface, busIds []string, requestBusIds []string, devRoot string) (Discover, error) {
	drmDeviceNodes, err := newDRMDeviceDiscoverer(logger, busIds, requestBusIds, devRoot)
	if err != nil {
		// %w (rather than %v) preserves the error chain; the rendered message is identical.
		return nil, fmt.Errorf("failed to create DRM device discoverer: %w", err)
	}

	return drmDeviceNodes, nil
}

// newDRMDeviceDiscoverer creates a discoverer for the DRM devices associated with the requested devices.
//
// It discovers every /dev/dri/card* and /dev/dri/renderD* character device under
// devRoot. It then keeps only those whose path is among the DRM nodes that
// drm.GetDeviceNodesByBusID reports for each entry of requestBusIds.
//
// NOTE(review): busIds is accepted for signature compatibility but is not used here.
//
// An error is returned, wrapped with %w, if the DRM nodes for any requested bus ID
// cannot be determined.
func newDRMDeviceDiscoverer(logger logger.Interface, busIds []string, requestBusIds []string, devRoot string) (Discover, error) {
	allDevices := NewCharDeviceDiscoverer(
		logger,
		devRoot,
		[]string{
			"/dev/dri/card*",
			"/dev/dri/renderD*",
		},
	)

	// Build the set of DRM node paths that belong to the requested buses.
	filter := make(selectDeviceByPath)
	for _, busId := range requestBusIds {
		drmDeviceNodes, err := drm.GetDeviceNodesByBusID(busId)
		if err != nil {
			// %w keeps the underlying error inspectable; the message text is unchanged.
			return nil, fmt.Errorf("failed to determine DRM devices for %v: %w", busId, err)
		}
		logger.Infof("selected drm nodes from bus %s: %v", busId, drmDeviceNodes)
		for _, drmDeviceNode := range drmDeviceNodes {
			filter[drmDeviceNode] = true
		}
	}

	// We return a discoverer that applies the DRM device filter created above to all discovered DRM device nodes.
	d := newFilteredDiscoverer(
		logger,
		allDevices,
		filter,
	)

	return d, nil
}

// selectDeviceByPath is a Filter keyed on device node path. A device passes the
// filter only when its Path is present in the map with a true value; mounts and
// hooks are never filtered out.
type selectDeviceByPath map[string]bool

var _ Filter = (*selectDeviceByPath)(nil)

// DeviceIsSelected reports whether device.Path has been marked as selected.
// Paths that are missing from the map, or mapped to false, are rejected.
func (s selectDeviceByPath) DeviceIsSelected(device Device) bool {
	selected, found := s[device.Path]
	return found && selected
}

// MountIsSelected accepts every mount unconditionally.
func (s selectDeviceByPath) MountIsSelected(Mount) bool {
	const acceptAllMounts = true
	return acceptAllMounts
}

// HookIsSelected accepts every hook unconditionally.
func (s selectDeviceByPath) HookIsSelected(Hook) bool {
	const acceptAllHooks = true
	return acceptAllHooks
}

// NewCommonHCUDiscoverer creates a discoverer for the mounts required by HCU.
//
// The returned discoverer merges, in order:
//   - the /dev/kfd, /dev/mkfd and /dev/mem character devices under driver.Root;
//   - the "hyhal" directory, located under /usr/local or /opt, when isMount is true;
//   - the user/group discoverer;
//   - a symlink hook mapping /usr/local/hyhal to /opt/hyhal, when /usr/local/hyhal
//     exists as a directory on the host;
//   - a track hook for containerImage.ContainerId, when the image sets either the
//     DTK or the NVIDIA visible-devices environment variable.
//
// When the image sets the VDTK visible-devices variable, its value is passed to
// the mounts discoverer via addVdcu.
//
// A hook is only merged if its Lifecycle is non-empty. A zero-value Hook is never
// passed to Merge.
//
// The returned error is currently always nil.
func NewCommonHCUDiscoverer(logger logger.Interface, dtkCDIHookPath string, driver *root.Driver, isMount bool, containerImage image.DTK) (Discover, error) {
	metaDevices := NewCharDeviceDiscoverer(
		logger,
		driver.Root,
		[]string{
			"/dev/kfd",
			"/dev/mkfd",
			"/dev/mem",
		},
	)

	var directories []string
	if isMount {
		directories = append(directories, "hyhal")
	}
	libraries := NewMounts(
		logger,
		lookup.NewDirectoryLocator(
			lookup.WithLogger(logger),
			lookup.WithCount(1),
			lookup.WithSearchPaths("/usr/local", "/opt"),
		),
		driver.Root,
		directories,
	)

	var linkHook Hook
	if info, err := os.Stat("/usr/local/hyhal"); err == nil && info.IsDir() {
		linkHook = CreateSymlinkHook(dtkCDIHookPath, []string{"/usr/local/hyhal::/opt/hyhal"})
	}

	dtkVisibleDevices := containerImage.Getenv(image.EnvVarDTKVisibleDevices)
	nvidiaVisibleDevices := containerImage.Getenv(image.EnvVarNvidiaVisibleDevices)
	vdtkVisibleDevices := containerImage.Getenv(image.EnvVarVDTKVisibleDevices)

	var trackHook Hook
	if len(dtkVisibleDevices) > 0 || len(nvidiaVisibleDevices) > 0 {
		trackHook = CreateTrackHook(dtkCDIHookPath, containerImage.ContainerId)
	}

	if len(vdtkVisibleDevices) > 0 {
		// addVdcu is defined on the concrete *mounts type, so it is only reachable
		// when NewMounts returned that type.
		if m, ok := libraries.(*mounts); ok {
			m.addVdcu(vdtkVisibleDevices)
		}
	}

	discoverers := []Discover{
		metaDevices,
		libraries,
		NewUserGroupDiscover(logger),
	}
	// An empty Lifecycle marks a hook that was never created. Previously only the
	// track hook was guarded this way, and a zero-value link hook was merged
	// whenever /usr/local/hyhal was missing.
	for _, hook := range []Hook{linkHook, trackHook} {
		if hook.Lifecycle != "" {
			discoverers = append(discoverers, hook)
		}
	}

	return Merge(discoverers...), nil
}
