/**
# Copyright (c) 2024, HCUOpt CORPORATION.  All rights reserved.
**/

package transform

import (
	"os"
	"path/filepath"
	"sort"
	"strings"

	"tags.cncf.io/container-device-interface/specs-go"
)

// sorter is a stateless type whose methods sort the devices, device nodes,
// and mounts of a CDI specification.
type sorter struct{}

// Compile-time assertion that *sorter satisfies the Transformer interface.
var _ Transformer = (*sorter)(nil)

// NewSorter creates a transformer that sorts container edits.
//
// The returned Transformer is never nil; it is a stateless sorter, so calling
// Transform on the result is always safe.
func NewSorter() Transformer {
	// Return a concrete sorter rather than a nil interface: a nil Transformer
	// would panic as soon as a caller invoked Transform on it.
	return &sorter{}
}

// Transform sorts the entities in the specified CDI specification.
//
// A nil spec is accepted and left untouched. Otherwise the spec-level
// container edits are sorted first, followed by the container edits of every
// device; the devices themselves are then sorted and written back to the spec.
// The first error reported while sorting edits is returned unchanged.
func (d sorter) Transform(spec *specs.Spec) error {
	if spec == nil {
		return nil
	}

	if err := d.transformEdits(&spec.ContainerEdits); err != nil {
		return err
	}

	// Work on a copy of each device and collect the copies in a fresh slice.
	var sorted []specs.Device
	for i := range spec.Devices {
		dev := spec.Devices[i]
		if err := d.transformEdits(&dev.ContainerEdits); err != nil {
			return err
		}
		sorted = append(sorted, dev)
	}

	spec.Devices = d.sortDevices(sorted)
	return nil
}

// transformEdits sorts the device nodes and the mounts of the given container
// edits and stores the sorted slices back on edits. It always returns nil.
func (d sorter) transformEdits(edits *specs.ContainerEdits) error {
	nodes := d.sortDeviceNodes(edits.DeviceNodes)
	mounts := d.sortMounts(edits.Mounts)

	edits.DeviceNodes, edits.Mounts = nodes, mounts
	return nil
}

// sortDevices orders the given devices by ascending name. The slice is sorted
// in place and the same slice is returned.
func (d sorter) sortDevices(devices []specs.Device) []specs.Device {
	byName := func(a, b int) bool {
		return devices[a].Name < devices[b].Name
	}
	sort.Slice(devices, byName)
	return devices
}

// sortDeviceNodes sorts the specified device nodes in place by container path
// and returns the same slice.
//
// Nodes are ordered first by the depth of their cleaned container path (the
// number of path separators it contains), so shallower paths come before
// deeper ones, and then lexicographically by container path. If two device
// nodes have the same container path, the host path is used to break ties.
func (d sorter) sortDeviceNodes(entities []*specs.DeviceNode) []*specs.DeviceNode {
	separator := string(os.PathSeparator)
	depth := func(p string) int {
		return strings.Count(filepath.Clean(p), separator)
	}

	sort.Slice(entities, func(i, j int) bool {
		a, b := entities[i], entities[j]
		if da, db := depth(a.Path), depth(b.Path); da != db {
			return da < db
		}
		if a.Path != b.Path {
			return a.Path < b.Path
		}
		// sort.Slice is not stable: without an explicit tie-break, nodes that
		// share a container path would end up in an unspecified order.
		return a.HostPath < b.HostPath
	})
	return entities
}

// sortMounts sorts the specified mounts in place by container path and returns
// the same slice.
//
// Mounts are ordered first by the depth of their cleaned container path (the
// number of path separators it contains), so shallower paths come before
// deeper ones, and then lexicographically by container path. If two mounts
// have the same container path, the host path is used to break ties.
func (d sorter) sortMounts(entities []*specs.Mount) []*specs.Mount {
	separator := string(os.PathSeparator)
	depth := func(p string) int {
		return strings.Count(filepath.Clean(p), separator)
	}

	sort.Slice(entities, func(i, j int) bool {
		a, b := entities[i], entities[j]
		if da, db := depth(a.ContainerPath), depth(b.ContainerPath); da != db {
			return da < db
		}
		if a.ContainerPath != b.ContainerPath {
			return a.ContainerPath < b.ContainerPath
		}
		// sort.Slice is not stable: without an explicit tie-break, mounts that
		// share a container path would end up in an unspecified order.
		return a.HostPath < b.HostPath
	})
	return entities
}
