spec.go 2.71 KB
Newer Older
songlinfeng's avatar
songlinfeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/**
# Copyright (c) 2024, HCUOpt CORPORATION.  All rights reserved.
**/

package spec

import (
	"dtk-container-toolkit/pkg/c3000cdi/transform"
	"fmt"
	"io"
	"os"
	"path/filepath"

	"tags.cncf.io/container-device-interface/pkg/cdi"
	"tags.cncf.io/container-device-interface/specs-go"
)

// spec wraps a raw CDI specification together with the settings that
// control how it is written out.
type spec struct {
	// Spec is the embedded raw CDI specification; it is what Raw returns.
	*specs.Spec
	// format selects the file extension appended to paths that do not
	// already end in .yaml or .json (see extension).
	format string
	// permissions is the file mode applied to the spec file by Save.
	permissions os.FileMode
	// transformOnSave, when non-nil, is applied to the raw spec at the
	// start of every Save.
	transformOnSave transform.Transformer
}

// Compile-time assertion that *spec satisfies Interface.
var _ Interface = (*spec)(nil)

// New constructs a spec from the supplied options. The options are handed
// to a builder, and the result of building is returned unchanged.
func New(opts ...Option) (Interface, error) {
	b := newBuilder(opts...)
	return b.Build()
}

// Save writes the spec to the specified path and overwrites the file if it exists.
//
// If a transform was configured it is applied first; the transform receives
// the pointer returned by Raw, so it may modify the spec in place. The path
// is then normalized (see normalizePath), the spec is written through a CDI
// cache rooted at the path's directory, and finally the file mode is set to
// s.permissions.
//
// It returns a wrapped error if the transform, path normalization, cache
// creation, write, or chmod step fails.
func (s *spec) Save(path string) error {
	if s.transformOnSave != nil {
		err := s.transformOnSave.Transform(s.Raw())
		if err != nil {
			return fmt.Errorf("error applying transform: %w", err)
		}
	}
	path, err := s.normalizePath(path)
	if err != nil {
		return fmt.Errorf("failed to normalize path: %w", err)
	}

	// Use a dedicated cache limited to the target directory, with
	// auto-refresh disabled, purely as a means of writing the spec file.
	specDir := filepath.Dir(path)
	cache, err := cdi.NewCache(
		cdi.WithAutoRefresh(false),
		cdi.WithSpecDirs(specDir),
	)
	if err != nil {
		// Do not continue with a cache that failed to initialize.
		return fmt.Errorf("failed to create CDI cache for %q: %w", specDir, err)
	}
	if err := cache.WriteSpec(s.Raw(), filepath.Base(path)); err != nil {
		return fmt.Errorf("failed to write spec: %w", err)
	}

	if err := os.Chmod(path, s.permissions); err != nil {
		return fmt.Errorf("failed to set permissions on spec file: %w", err)
	}

	return nil
}

// WriteTo writes the spec to the specified writer.
//
// The spec is first saved to a temporary file via Save, so the same
// transform, serialization and permission handling apply, and the contents
// of that file are then copied to w. The temporary file is removed before
// returning.
//
// It returns the number of bytes copied to w and the first error encountered.
func (s *spec) WriteTo(w io.Writer) (int64, error) {
	name, err := cdi.GenerateNameForSpec(s.Raw())
	if err != nil {
		return 0, err
	}

	// Only the base name (with its extension) is used below, and
	// normalizePath returns the path even when it fails to resolve the
	// working directory, so the error is intentionally ignored here.
	path, _ := s.normalizePath(name)
	tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path))
	if err != nil {
		return 0, err
	}
	defer os.Remove(tmpFile.Name())

	// Only the file's name is needed; Save writes to the path itself. Close
	// the handle right away so that it is not leaked when Save fails and is
	// not held open while the file is being replaced.
	if err := tmpFile.Close(); err != nil {
		return 0, fmt.Errorf("failed to close temporary file: %w", err)
	}

	if err := s.Save(tmpFile.Name()); err != nil {
		return 0, err
	}

	r, err := os.Open(tmpFile.Name())
	if err != nil {
		return 0, fmt.Errorf("failed to open temporary file: %w", err)
	}
	defer r.Close()

	return io.Copy(w, r)
}

// Raw returns a pointer to the raw spec. This is the embedded *specs.Spec
// itself, not a copy, so changes made through the returned pointer are
// visible to the spec.
func (s *spec) Raw() *specs.Spec {
	return s.Spec
}

// normalizePath ensures that the specified path has a supported extension
// and that a bare file name is anchored in the current working directory.
//
// If the path ends in neither ".yaml" nor ".json", the extension for the
// spec's configured format is appended. If the resulting path has no
// directory component, it is joined onto the current working directory;
// paths that already name a directory are returned as they are.
//
// When the working directory cannot be determined, the extension-normalized
// path is still returned alongside the wrapped error.
func (s *spec) normalizePath(path string) (string, error) {
	if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" {
		path += s.extension()
	}

	if filepath.Clean(filepath.Dir(path)) == "." {
		pwd, err := os.Getwd()
		if err != nil {
			// Wrap with %w so callers can inspect the underlying error.
			return path, fmt.Errorf("failed to get current working directory: %w", err)
		}
		path = filepath.Join(pwd, path)
	}

	return path, nil
}

// extension returns the file extension, including the leading dot, that
// corresponds to the spec's configured format.
func (s *spec) extension() string {
	if s.format == FormatJSON {
		return ".json"
	}
	// FormatYAML and any unrecognized format both map to the YAML extension.
	return ".yaml"
}