Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
b48083f3
Unverified
Commit
b48083f3
authored
Nov 13, 2025
by
Michael Yang
Committed by
GitHub
Nov 13, 2025
Browse files
ml: add slice operation (#12870)
* slice * chunk, chunksections
parent
482bec82
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
773 additions
and
1 deletion
+773
-1
ml/backend.go
ml/backend.go
+4
-0
ml/backend/ggml/ggml.go
ml/backend/ggml/ggml.go
+63
-0
ml/backend/ggml/ggml_test.go
ml/backend/ggml/ggml_test.go
+706
-1
No files found.
ml/backend.go
View file @
b48083f3
...
...
@@ -198,6 +198,10 @@ type Tensor interface {
Copy
(
ctx
Context
,
t2
Tensor
)
Tensor
Duplicate
(
ctx
Context
)
Tensor
Slice
(
ctx
Context
,
dim
,
low
,
high
,
step
int
)
Tensor
Chunk
(
ctx
Context
,
dim
int
,
size
int
)
[]
Tensor
ChunkSections
(
ctx
Context
,
dim
int
,
sections
...
int
)
[]
Tensor
TopK
(
ctx
Context
,
k
int
)
Tensor
Argsort
(
ctx
Context
)
Tensor
Mean
(
ctx
Context
)
Tensor
...
...
ml/backend/ggml/ggml.go
View file @
b48083f3
...
...
@@ -1738,3 +1738,66 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
t
:
C
.
ggml_clamp
(
ctx
.
(
*
Context
)
.
ctx
,
t
.
t
,
C
.
float
(
min
),
C
.
float
(
max
)),
}
}
// Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
func
(
t
*
Tensor
)
Slice
(
ctx
ml
.
Context
,
dim
int
,
low
,
high
,
step
int
)
ml
.
Tensor
{
if
dim
<
0
||
dim
>=
C
.
GGML_MAX_DIMS
{
panic
(
"invalid dimension"
)
}
else
if
low
<
0
||
high
>
t
.
Dim
(
dim
)
||
low
>=
high
||
step
<
1
{
panic
(
"invalid slice parameters"
)
}
if
dim
==
0
&&
step
>
1
{
// dim=0,step>1 is a special case so handle it here first
return
t
.
View
(
ctx
,
low
*
t
.
Stride
(
0
),
1
,
step
*
t
.
Stride
(
0
),
(
high
-
low
+
1
)
/
step
,
t
.
Stride
(
1
),
t
.
Dim
(
1
),
// preserve dim 3 by merging it into dim 2
t
.
Stride
(
2
),
t
.
Dim
(
2
)
*
t
.
Dim
(
3
),
)
.
Contiguous
(
ctx
,
(
high
-
low
+
1
)
/
step
,
t
.
Dim
(
1
),
t
.
Dim
(
2
),
t
.
Dim
(
3
))
}
args
:=
[]
int
{
low
*
t
.
Stride
(
dim
),
t
.
Dim
(
0
),
t
.
Stride
(
1
),
t
.
Dim
(
1
),
t
.
Stride
(
2
),
t
.
Dim
(
2
),
t
.
Stride
(
3
),
t
.
Dim
(
3
),
}
if
step
==
1
{
args
[
dim
*
2
+
1
]
=
high
-
low
return
t
.
View
(
ctx
,
args
[
0
],
args
[
1
:
]
...
)
}
else
{
args
[
dim
*
2
]
=
step
*
t
.
Stride
(
dim
)
args
[
dim
*
2
+
1
]
=
(
high
-
low
+
1
)
/
step
return
t
.
View
(
ctx
,
args
[
0
],
args
[
1
:
]
...
)
}
}
// Chunk the tensor into chunk sized tensors along dim. Each sub-tensor is a view of
// the original.
func
(
t
*
Tensor
)
Chunk
(
ctx
ml
.
Context
,
dim
,
chunk
int
)
[]
ml
.
Tensor
{
sections
:=
make
([]
int
,
0
,
t
.
Dim
(
dim
)
/
chunk
+
1
)
for
rest
:=
t
.
Dim
(
dim
);
rest
>
0
;
rest
-=
chunk
{
sections
=
append
(
sections
,
min
(
chunk
,
rest
))
}
return
t
.
ChunkSections
(
ctx
,
dim
,
sections
...
)
}
// ChunkSections split the tensor into section sized tensors along dim. Each sub-tensor is a
// view of the original. The size of the dim must equal the sum of sections.
func
(
t
*
Tensor
)
ChunkSections
(
ctx
ml
.
Context
,
dim
int
,
sections
...
int
)
[]
ml
.
Tensor
{
var
offset
int
s
:=
make
([]
ml
.
Tensor
,
len
(
sections
))
for
i
,
section
:=
range
sections
{
s
[
i
]
=
t
.
Slice
(
ctx
,
dim
,
offset
,
offset
+
section
,
1
)
offset
+=
section
}
if
offset
!=
t
.
Dim
(
dim
)
{
panic
(
"sections do not sum to tensor dimension"
)
}
return
s
}
ml/backend/ggml/ggml_test.go
View file @
b48083f3
...
...
@@ -2,6 +2,7 @@ package ggml
import
(
"errors"
"fmt"
"os"
"testing"
...
...
@@ -368,10 +369,714 @@ func TestPermute(t *testing.T) {
for
_
,
tt
:=
range
cases
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
ctx
:=
setup
(
t
)
got
:=
tt
.
input
(
ctx
)
.
Permute
(
ctx
,
tt
.
shape
...
)
.
Contiguous
(
ctx
)
got
:=
tt
.
input
(
ctx
)
.
Permute
(
ctx
,
tt
.
shape
...
)
got
=
got
.
Contiguous
(
ctx
)
if
diff
:=
cmp
.
Diff
(
tt
.
want
(
ctx
),
got
,
EquateTensors
(
ctx
));
diff
!=
""
{
t
.
Errorf
(
"Permute() result mismatch (-want +got):
\n
%s"
,
diff
)
}
})
}
}
func
TestSlice
(
t
*
testing
.
T
)
{
cases
:=
[]
struct
{
dim
int
low
int
high
int
step
int
input
func
(
ml
.
Context
)
ml
.
Tensor
want
func
(
ml
.
Context
)
ml
.
Tensor
}{
{
dim
:
0
,
low
:
1
,
high
:
3
,
step
:
1
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
1
,
2
,
5
,
6
,
9
,
10
,
13
,
14
,
17
,
18
,
21
,
22
,
25
,
26
,
29
,
30
,
33
,
34
,
37
,
38
,
41
,
42
,
45
,
46
,
49
,
50
,
53
,
54
,
57
,
58
,
61
,
62
,
65
,
66
,
69
,
70
,
73
,
74
,
77
,
78
,
81
,
82
,
85
,
86
,
89
,
90
,
93
,
94
,
97
,
98
,
101
,
102
,
105
,
106
,
109
,
110
,
113
,
114
,
117
,
118
,
121
,
122
,
125
,
126
,
129
,
130
,
133
,
134
,
137
,
138
,
141
,
142
,
145
,
146
,
149
,
150
,
153
,
154
,
157
,
158
,
161
,
162
,
165
,
166
,
169
,
170
,
173
,
174
,
177
,
178
,
181
,
182
,
185
,
186
,
189
,
190
,
193
,
194
,
197
,
198
,
201
,
202
,
205
,
206
,
209
,
210
,
213
,
214
,
217
,
218
,
221
,
222
,
225
,
226
,
229
,
230
,
233
,
234
,
237
,
238
,
241
,
242
,
245
,
246
,
249
,
250
,
253
,
254
,
},
2
,
4
,
4
,
4
)
},
},
{
dim
:
1
,
low
:
1
,
high
:
3
,
step
:
1
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
,
68
,
69
,
70
,
71
,
72
,
73
,
74
,
75
,
84
,
85
,
86
,
87
,
88
,
89
,
90
,
91
,
100
,
101
,
102
,
103
,
104
,
105
,
106
,
107
,
116
,
117
,
118
,
119
,
120
,
121
,
122
,
123
,
132
,
133
,
134
,
135
,
136
,
137
,
138
,
139
,
148
,
149
,
150
,
151
,
152
,
153
,
154
,
155
,
164
,
165
,
166
,
167
,
168
,
169
,
170
,
171
,
180
,
181
,
182
,
183
,
184
,
185
,
186
,
187
,
196
,
197
,
198
,
199
,
200
,
201
,
202
,
203
,
212
,
213
,
214
,
215
,
216
,
217
,
218
,
219
,
228
,
229
,
230
,
231
,
232
,
233
,
234
,
235
,
244
,
245
,
246
,
247
,
248
,
249
,
250
,
251
,
},
4
,
2
,
4
,
4
)
},
},
{
dim
:
2
,
low
:
1
,
high
:
3
,
step
:
1
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
80
,
81
,
82
,
83
,
84
,
85
,
86
,
87
,
88
,
89
,
90
,
91
,
92
,
93
,
94
,
95
,
96
,
97
,
98
,
99
,
100
,
101
,
102
,
103
,
104
,
105
,
106
,
107
,
108
,
109
,
110
,
111
,
144
,
145
,
146
,
147
,
148
,
149
,
150
,
151
,
152
,
153
,
154
,
155
,
156
,
157
,
158
,
159
,
160
,
161
,
162
,
163
,
164
,
165
,
166
,
167
,
168
,
169
,
170
,
171
,
172
,
173
,
174
,
175
,
208
,
209
,
210
,
211
,
212
,
213
,
214
,
215
,
216
,
217
,
218
,
219
,
220
,
221
,
222
,
223
,
224
,
225
,
226
,
227
,
228
,
229
,
230
,
231
,
232
,
233
,
234
,
235
,
236
,
237
,
238
,
239
,
},
4
,
4
,
2
,
4
)
},
},
{
dim
:
3
,
low
:
1
,
high
:
3
,
step
:
1
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
64
,
65
,
66
,
67
,
68
,
69
,
70
,
71
,
72
,
73
,
74
,
75
,
76
,
77
,
78
,
79
,
80
,
81
,
82
,
83
,
84
,
85
,
86
,
87
,
88
,
89
,
90
,
91
,
92
,
93
,
94
,
95
,
96
,
97
,
98
,
99
,
100
,
101
,
102
,
103
,
104
,
105
,
106
,
107
,
108
,
109
,
110
,
111
,
112
,
113
,
114
,
115
,
116
,
117
,
118
,
119
,
120
,
121
,
122
,
123
,
124
,
125
,
126
,
127
,
128
,
129
,
130
,
131
,
132
,
133
,
134
,
135
,
136
,
137
,
138
,
139
,
140
,
141
,
142
,
143
,
144
,
145
,
146
,
147
,
148
,
149
,
150
,
151
,
152
,
153
,
154
,
155
,
156
,
157
,
158
,
159
,
160
,
161
,
162
,
163
,
164
,
165
,
166
,
167
,
168
,
169
,
170
,
171
,
172
,
173
,
174
,
175
,
176
,
177
,
178
,
179
,
180
,
181
,
182
,
183
,
184
,
185
,
186
,
187
,
188
,
189
,
190
,
191
,
},
4
,
4
,
4
,
2
)
},
},
{
dim
:
0
,
low
:
0
,
high
:
4
,
step
:
2
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
2
,
4
,
6
,
8
,
10
,
12
,
14
,
16
,
18
,
20
,
22
,
24
,
26
,
28
,
30
,
32
,
34
,
36
,
38
,
40
,
42
,
44
,
46
,
48
,
50
,
52
,
54
,
56
,
58
,
60
,
62
,
64
,
66
,
68
,
70
,
72
,
74
,
76
,
78
,
80
,
82
,
84
,
86
,
88
,
90
,
92
,
94
,
96
,
98
,
100
,
102
,
104
,
106
,
108
,
110
,
112
,
114
,
116
,
118
,
120
,
122
,
124
,
126
,
128
,
130
,
132
,
134
,
136
,
138
,
140
,
142
,
144
,
146
,
148
,
150
,
152
,
154
,
156
,
158
,
160
,
162
,
164
,
166
,
168
,
170
,
172
,
174
,
176
,
178
,
180
,
182
,
184
,
186
,
188
,
190
,
192
,
194
,
196
,
198
,
200
,
202
,
204
,
206
,
208
,
210
,
212
,
214
,
216
,
218
,
220
,
222
,
224
,
226
,
228
,
230
,
232
,
234
,
236
,
238
,
240
,
242
,
244
,
246
,
248
,
250
,
252
,
254
,
},
2
,
4
,
4
,
4
)
},
},
{
dim
:
1
,
low
:
0
,
high
:
4
,
step
:
2
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
2
,
3
,
8
,
9
,
10
,
11
,
16
,
17
,
18
,
19
,
24
,
25
,
26
,
27
,
32
,
33
,
34
,
35
,
40
,
41
,
42
,
43
,
48
,
49
,
50
,
51
,
56
,
57
,
58
,
59
,
64
,
65
,
66
,
67
,
72
,
73
,
74
,
75
,
80
,
81
,
82
,
83
,
88
,
89
,
90
,
91
,
96
,
97
,
98
,
99
,
104
,
105
,
106
,
107
,
112
,
113
,
114
,
115
,
120
,
121
,
122
,
123
,
128
,
129
,
130
,
131
,
136
,
137
,
138
,
139
,
144
,
145
,
146
,
147
,
152
,
153
,
154
,
155
,
160
,
161
,
162
,
163
,
168
,
169
,
170
,
171
,
176
,
177
,
178
,
179
,
184
,
185
,
186
,
187
,
192
,
193
,
194
,
195
,
200
,
201
,
202
,
203
,
208
,
209
,
210
,
211
,
216
,
217
,
218
,
219
,
224
,
225
,
226
,
227
,
232
,
233
,
234
,
235
,
240
,
241
,
242
,
243
,
248
,
249
,
250
,
251
,
},
4
,
2
,
4
,
4
)
},
},
{
dim
:
2
,
low
:
0
,
high
:
4
,
step
:
2
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
64
,
65
,
66
,
67
,
68
,
69
,
70
,
71
,
72
,
73
,
74
,
75
,
76
,
77
,
78
,
79
,
96
,
97
,
98
,
99
,
100
,
101
,
102
,
103
,
104
,
105
,
106
,
107
,
108
,
109
,
110
,
111
,
128
,
129
,
130
,
131
,
132
,
133
,
134
,
135
,
136
,
137
,
138
,
139
,
140
,
141
,
142
,
143
,
160
,
161
,
162
,
163
,
164
,
165
,
166
,
167
,
168
,
169
,
170
,
171
,
172
,
173
,
174
,
175
,
192
,
193
,
194
,
195
,
196
,
197
,
198
,
199
,
200
,
201
,
202
,
203
,
204
,
205
,
206
,
207
,
224
,
225
,
226
,
227
,
228
,
229
,
230
,
231
,
232
,
233
,
234
,
235
,
236
,
237
,
238
,
239
,
},
4
,
4
,
2
,
4
)
},
},
{
dim
:
3
,
low
:
0
,
high
:
4
,
step
:
2
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
4
*
4
*
4
*
4
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
4
,
4
,
4
)
},
want
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
,
48
,
49
,
50
,
51
,
52
,
53
,
54
,
55
,
56
,
57
,
58
,
59
,
60
,
61
,
62
,
63
,
128
,
129
,
130
,
131
,
132
,
133
,
134
,
135
,
136
,
137
,
138
,
139
,
140
,
141
,
142
,
143
,
144
,
145
,
146
,
147
,
148
,
149
,
150
,
151
,
152
,
153
,
154
,
155
,
156
,
157
,
158
,
159
,
160
,
161
,
162
,
163
,
164
,
165
,
166
,
167
,
168
,
169
,
170
,
171
,
172
,
173
,
174
,
175
,
176
,
177
,
178
,
179
,
180
,
181
,
182
,
183
,
184
,
185
,
186
,
187
,
188
,
189
,
190
,
191
,
},
4
,
4
,
4
,
2
)
},
},
}
for
_
,
tt
:=
range
cases
{
name
:=
fmt
.
Sprintf
(
"dim=%d,low=%d,high=%d,step=%d"
,
tt
.
dim
,
tt
.
low
,
tt
.
high
,
tt
.
step
)
t
.
Run
(
name
,
func
(
t
*
testing
.
T
)
{
ctx
:=
setup
(
t
)
got
:=
tt
.
input
(
ctx
)
.
Slice
(
ctx
,
tt
.
dim
,
tt
.
low
,
tt
.
high
,
tt
.
step
)
got
=
got
.
Contiguous
(
ctx
)
if
diff
:=
cmp
.
Diff
(
tt
.
want
(
ctx
),
got
,
EquateTensors
(
ctx
));
diff
!=
""
{
t
.
Errorf
(
"Slice() result mismatch (-want +got):
\n
%s"
,
diff
)
}
})
}
}
func
TestSplitSections
(
t
*
testing
.
T
)
{
cases
:=
[]
struct
{
dim
int
sections
[]
int
input
func
(
ml
.
Context
)
ml
.
Tensor
want
[]
func
(
ml
.
Context
)
ml
.
Tensor
}{
{
dim
:
0
,
sections
:
[]
int
{
1
,
1
,
1
},
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
3
,
4
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
3
,
6
,
9
},
1
,
4
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
1
,
4
,
7
,
10
},
1
,
4
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
2
,
5
,
8
,
11
},
1
,
4
)
},
},
},
{
dim
:
1
,
sections
:
[]
int
{
1
,
3
},
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
3
,
4
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
2
},
3
,
1
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
},
3
,
3
)
},
},
},
{
dim
:
0
,
sections
:
[]
int
{
2
,
2
},
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
3
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
4
,
5
,
8
,
9
,
},
2
,
3
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
2
,
3
,
6
,
7
,
10
,
11
,
},
2
,
3
)
},
},
},
{
dim
:
1
,
sections
:
[]
int
{
1
,
2
},
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
4
,
3
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
2
,
3
},
4
,
1
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
},
4
,
2
)
},
},
},
}
for
_
,
tt
:=
range
cases
{
t
.
Run
(
fmt
.
Sprintf
(
"sections=%v"
,
tt
.
sections
),
func
(
t
*
testing
.
T
)
{
ctx
:=
setup
(
t
)
got
:=
tt
.
input
(
ctx
)
.
ChunkSections
(
ctx
,
tt
.
dim
,
tt
.
sections
...
)
for
i
:=
range
got
{
got
[
i
]
=
got
[
i
]
.
Contiguous
(
ctx
)
}
ctx
.
Forward
(
got
...
)
.
Compute
(
got
...
)
for
i
,
want
:=
range
tt
.
want
{
if
diff
:=
cmp
.
Diff
(
want
(
ctx
),
got
[
i
],
EquateTensors
(
ctx
));
diff
!=
""
{
t
.
Errorf
(
"SplitSections() section %d mismatch (-want +got):
\n
%s"
,
i
,
diff
)
}
}
})
}
}
func
TestChunk
(
t
*
testing
.
T
)
{
cases
:=
[]
struct
{
dim
int
chunk
int
input
func
(
ml
.
Context
)
ml
.
Tensor
want
[]
func
(
ml
.
Context
)
ml
.
Tensor
}{
{
dim
:
0
,
chunk
:
1
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
3
,
4
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
3
,
6
,
9
},
1
,
4
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
1
,
4
,
7
,
10
},
1
,
4
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
2
,
5
,
8
,
11
},
1
,
4
)
},
},
},
{
dim
:
1
,
chunk
:
2
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
3
,
4
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
2
,
3
,
4
,
5
,
},
3
,
2
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
6
,
7
,
8
,
9
,
10
,
11
,
},
3
,
2
)
},
},
},
{
dim
:
0
,
chunk
:
2
,
input
:
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
Arange
(
0
,
12
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
3
,
4
)
},
want
:
[]
func
(
ml
.
Context
)
ml
.
Tensor
{
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
0
,
1
,
3
,
4
,
6
,
7
,
9
,
10
,
},
2
,
4
)
},
func
(
ctx
ml
.
Context
)
ml
.
Tensor
{
return
ctx
.
FromFloats
([]
float32
{
2
,
5
,
8
,
11
,
},
1
,
4
)
},
},
},
}
for
_
,
tt
:=
range
cases
{
t
.
Run
(
fmt
.
Sprintf
(
"dim=%d,chunk=%d"
,
tt
.
dim
,
tt
.
chunk
),
func
(
t
*
testing
.
T
)
{
ctx
:=
setup
(
t
)
got
:=
tt
.
input
(
ctx
)
.
Chunk
(
ctx
,
tt
.
dim
,
tt
.
chunk
)
for
i
:=
range
got
{
got
[
i
]
=
got
[
i
]
.
Contiguous
(
ctx
)
}
ctx
.
Forward
(
got
...
)
.
Compute
(
got
...
)
for
i
,
want
:=
range
tt
.
want
{
if
diff
:=
cmp
.
Diff
(
want
(
ctx
),
got
[
i
],
EquateTensors
(
ctx
));
diff
!=
""
{
t
.
Errorf
(
"Split() section %d mismatch (-want +got):
\n
%s"
,
i
,
diff
)
}
}
})
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment