Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
21ef8820
Commit
21ef8820
authored
Aug 11, 2025
by
wooway777
Browse files
issue/21 - fixed view() implementation
parent
b3275d7c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
109 deletions
+44
-109
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+6
-6
src/tensor/tensor.cpp
src/tensor/tensor.cpp
+38
-103
No files found.
src/models/jiuge/jiuge.cpp
View file @
21ef8820
...
...
@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
result_buf
=
Tensor
::
buffer
(
INFINI_DTYPE_I64
,
{
nreq
},
rsrc
.
memory_pool
);
auto
result_cpu
=
std
::
vector
<
int64_t
>
(
nreq
);
auto
qkv_rope
=
qkv_buf
->
view
_as
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
auto
qkv_rope
=
qkv_buf
->
view
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
// Prepare inputs
auto
batch_pos_ids
=
std
::
vector
<
uint32_t
>
(
ntok
);
...
...
@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
,
max_qk_size
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
q_rearrange
=
rearrange_q_buf
->
view
_as
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
q_rearrange
=
rearrange_q_buf
->
view
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
attn_val_gemm
=
attn_val_buf
->
view
_as
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
attn_val_gemm
=
attn_val_buf
->
view
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
// MLP buffers
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
...
...
@@ -207,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
past_len
=
req_pos
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
total_len
=
past_len
+
seq_len
;
auto
o
=
o_buf
->
slice
({{
0
,
token_offset
,
seq_len
}})
->
view
_as
({
seq_len
,
nkvh
,
ngroup
,
dh
})
->
permute
({
1
,
2
,
0
,
3
});
auto
o
=
o_buf
->
slice
({{
0
,
token_offset
,
seq_len
}})
->
view
({
seq_len
,
nkvh
,
ngroup
,
dh
})
->
permute
({
1
,
2
,
0
,
3
});
auto
q
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}})
->
view
({
seq_len
,
nkvh
,
ngroup
,
dh
})
->
permute
({
1
,
2
,
0
,
3
});
auto
k
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
v
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
...
...
@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
rearrange
(
q_rearrange
,
q
);
auto
qk_gemm
=
qk_buf
->
view
_as
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
qk_gemm
=
qk_buf
->
view
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
linear
(
qk_gemm
,
rearrange_q_buf
,
k_gemm
,
1.
/
sqrt
(
dh
),
0.0
,
nullptr
,
nullptr
);
// softmax
auto
qk_softmax
=
qk_buf
->
view
_as
({
nh
,
seq_len
,
total_len
});
auto
qk_softmax
=
qk_buf
->
view
({
nh
,
seq_len
,
total_len
});
causalSoftmax
(
qk_softmax
,
qk_softmax
);
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
,
nullptr
);
...
...
src/tensor/tensor.cpp
View file @
21ef8820
...
...
@@ -273,128 +273,63 @@ size_t Tensor::seed() const {
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
{
//
Calculate total number of elements
//
Step 1: Validate total size
size_t
numel
=
1
;
for
(
auto
s
:
shape
())
{
numel
*=
s
;
for
(
size_t
dim
:
this
->
_desc
->
shape
())
{
numel
*=
dim
;
}
size_t
new_numel
=
1
;
for
(
auto
s
:
new_shape
)
{
new_numel
*=
s
;
for
(
size_t
dim
:
new_shape
)
{
new_numel
*=
dim
;
}
ASSERT
(
numel
==
new_numel
);
ASSERT
_EQ
(
numel
,
new_numel
);
// Handle empty tensors
if
(
numel
==
0
)
{
return
this
->
view_as
(
new_shape
,
{});
}
// Step 2: Get current shape and strides
const
std
::
vector
<
size_t
>
&
old_shape
=
this
->
_desc
->
shape
();
const
std
::
vector
<
ptrdiff_t
>
&
old_strides
=
this
->
_desc
->
strides
();
// Special case: view(-1) flattens the tensor
if
(
new_shape
.
size
()
==
1
&&
new_shape
[
0
]
==
static_cast
<
size_t
>
(
-
1
))
{
std
::
vector
<
size_t
>
flat_shape
=
{
numel
};
return
this
->
view_as
(
flat_shape
,
{});
}
// Step 3: Create merged shape and strides
std
::
vector
<
size_t
>
merged_shape
;
std
::
vector
<
ptrdiff_t
>
merged_strides
;
// Check for -1 in new_shape (infer dimension)
std
::
vector
<
size_t
>
inferred_shape
=
new_shape
;
size_t
infer_index
=
static_cast
<
size_t
>
(
-
1
);
size_t
known_elements
=
1
;
if
(
!
old_shape
.
empty
())
{
merged_shape
.
push_back
(
old_shape
[
0
]);
merged_strides
.
push_back
(
old_strides
[
0
]);
for
(
size_t
i
=
0
;
i
<
new
_shape
.
size
();
++
i
)
{
if
(
new_shape
[
i
]
==
static_cast
<
size_t
>
(
-
1
))
{
ASSERT
(
infer_index
==
static_cast
<
size_t
>
(
-
1
));
// Only one -1 allowed
infer_index
=
i
;
for
(
size_t
i
=
1
;
i
<
old
_shape
.
size
();
++
i
)
{
if
(
old_strides
[
i
]
*
static_cast
<
ptrdiff_t
>
(
old_shape
[
i
])
==
merged_strides
.
back
(
))
{
merged_shape
.
back
()
*=
old_shape
[
i
];
merged_strides
.
back
()
=
old_strides
[
i
]
;
}
else
{
known_elements
*=
new_shape
[
i
];
}
}
if
(
infer_index
!=
static_cast
<
size_t
>
(
-
1
))
{
ASSERT
(
numel
%
known_elements
==
0
);
inferred_shape
[
infer_index
]
=
numel
/
known_elements
;
}
// For contiguous tensors, compute standard row-major strides
if
(
this
->
isContigous
())
{
std
::
vector
<
ptrdiff_t
>
new_strides
(
inferred_shape
.
size
());
if
(
!
inferred_shape
.
empty
())
{
new_strides
.
back
()
=
1
;
for
(
int
i
=
static_cast
<
int
>
(
inferred_shape
.
size
())
-
2
;
i
>=
0
;
--
i
)
{
new_strides
[
i
]
=
new_strides
[
i
+
1
]
*
static_cast
<
ptrdiff_t
>
(
inferred_shape
[
i
+
1
]);
merged_shape
.
push_back
(
old_shape
[
i
]);
merged_strides
.
push_back
(
old_strides
[
i
]);
}
}
return
this
->
view_as
(
inferred_shape
,
new_strides
);
}
// For non-contiguous tensors
std
::
vector
<
size_t
>
old_shape
=
shape
();
std
::
vector
<
ptrdiff_t
>
old_strides
=
strides
();
std
::
vector
<
ptrdiff_t
>
new_strides
(
inferred_shape
.
size
(),
0
);
size_t
old_idx
=
old_shape
.
size
()
-
1
;
size_t
new_idx
=
inferred_shape
.
size
()
-
1
;
if
(
new_idx
!=
static_cast
<
size_t
>
(
-
1
))
{
new_strides
[
new_idx
]
=
1
;
}
while
(
old_idx
!=
static_cast
<
size_t
>
(
-
1
)
&&
new_idx
!=
static_cast
<
size_t
>
(
-
1
))
{
size_t
old_size
=
old_shape
[
old_idx
];
size_t
new_size
=
inferred_shape
[
new_idx
];
if
(
old_size
==
1
)
{
old_idx
--
;
}
else
if
(
new_size
==
1
)
{
new_strides
[
new_idx
]
=
(
new_idx
==
inferred_shape
.
size
()
-
1
)
?
1
:
new_strides
[
new_idx
+
1
];
new_idx
--
;
}
else
if
(
old_size
==
new_size
)
{
new_strides
[
new_idx
]
=
old_strides
[
old_idx
];
old_idx
--
;
new_idx
--
;
}
else
if
(
old_size
<
new_size
)
{
size_t
combined_size
=
old_size
;
ptrdiff_t
combined_stride
=
old_strides
[
old_idx
];
old_idx
--
;
// Step 4: Compute new strides by splitting merged dimensions
std
::
vector
<
ptrdiff_t
>
new_strides
(
new_shape
.
size
());
size_t
merged_idx
=
0
;
ptrdiff_t
current_stride
=
merged_strides
[
0
];
size_t
remaining_size
=
merged_shape
[
0
];
while
(
old_idx
!=
static_cast
<
size_t
>
(
-
1
)
&&
combined_size
<
new_size
)
{
ASSERT
(
static_cast
<
size_t
>
(
old_strides
[
old_idx
])
==
old_shape
[
old_idx
+
1
]
*
static_cast
<
size_t
>
(
old_strides
[
old_idx
+
1
]));
combined_size
*=
old_shape
[
old_idx
];
combined_stride
=
old_strides
[
old_idx
];
old_idx
--
;
for
(
size_t
i
=
0
;
i
<
new_shape
.
size
();
++
i
)
{
// Find which merged dimension contains this new dimension
while
(
new_shape
[
i
]
>
remaining_size
)
{
ASSERT
(
++
merged_idx
<
merged_shape
.
size
());
current_stride
=
merged_strides
[
merged_idx
];
remaining_size
=
merged_shape
[
merged_idx
];
}
ASSERT
(
combined_size
==
new_size
);
new_strides
[
new_idx
]
=
combined_stride
;
new_idx
--
;
}
else
{
size_t
remaining_size
=
old_size
/
new_size
;
ASSERT
(
old_size
%
new_size
==
0
);
new_strides
[
new_idx
]
=
old_strides
[
old_idx
]
*
static_cast
<
ptrdiff_t
>
(
remaining_size
);
new_idx
--
;
if
(
remaining_size
!=
1
)
{
if
(
new_idx
!=
static_cast
<
size_t
>
(
-
1
))
{
inferred_shape
[
new_idx
]
=
remaining_size
;
new_strides
[
new_idx
]
=
old_strides
[
old_idx
];
new_idx
--
;
}
else
{
ASSERT
(
false
);
}
}
old_idx
--
;
}
}
ASSERT_EQ
(
remaining_size
%
new_shape
[
i
],
0
);
// Fill remaining dimensions (must be size 1)
while
(
new_idx
!=
static_cast
<
size_t
>
(
-
1
))
{
ASSERT
(
inferred_shape
[
new_idx
]
==
1
);
new_strides
[
new_idx
]
=
new_strides
[
new_idx
+
1
];
new_idx
--
;
new_strides
[
i
]
=
current_stride
*
(
remaining_size
/
new_shape
[
i
]);
remaining_size
/=
new_shape
[
i
];
}
return
this
->
view_as
(
inferred
_shape
,
new_strides
);
return
this
->
view_as
(
new
_shape
,
new_strides
);
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view_as
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment