Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
8853663e
Commit
8853663e
authored
Aug 08, 2025
by
wooway777
Browse files
issue/21 - Improved viewReshaped implementation and calls
parent
dd5dec97
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
16 deletions
+63
-16
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+4
-3
src/tensor.hpp
src/tensor.hpp
+1
-1
src/tensor/tensor.cpp
src/tensor/tensor.cpp
+58
-12
No files found.
src/models/jiuge/jiuge.cpp
View file @
8853663e
...
@@ -141,6 +141,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -141,6 +141,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
result_buf
=
Tensor
::
buffer
(
INFINI_DTYPE_I64
,
{
nreq
},
rsrc
.
memory_pool
);
auto
result_buf
=
Tensor
::
buffer
(
INFINI_DTYPE_I64
,
{
nreq
},
rsrc
.
memory_pool
);
auto
result_cpu
=
std
::
vector
<
int64_t
>
(
nreq
);
auto
result_cpu
=
std
::
vector
<
int64_t
>
(
nreq
);
auto
qkv_rope
=
qkv_buf
->
viewReshaped
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
// Prepare inputs
// Prepare inputs
auto
batch_pos_ids
=
std
::
vector
<
uint32_t
>
(
ntok
);
auto
batch_pos_ids
=
std
::
vector
<
uint32_t
>
(
ntok
);
size_t
req_start
=
0
;
size_t
req_start
=
0
;
...
@@ -181,7 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -181,7 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
,
max_qk_size
},
rsrc
.
memory_pool
);
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
,
max_qk_size
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
q_rearrange
=
rearrange_q_buf
->
viewReshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
attn_val_gemm
=
attn_val_buf
->
viewReshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
// MLP buffers
// MLP buffers
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
...
@@ -198,7 +202,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -198,7 +202,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
}
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
has_qkv_bias
?
qkv_buf
:
nullptr
);
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
has_qkv_bias
?
qkv_buf
:
nullptr
);
// rope
// rope
auto
qkv_rope
=
qkv_buf
->
viewReshaped
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
rope
(
qkv_rope
->
slice
(
1
,
0
,
nh
),
qkv_rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
0
,
nh
),
qkv_rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
...
@@ -217,7 +220,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -217,7 +220,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
// qk
auto
q_rearrange
=
rearrange_q_buf
->
viewReshaped
({
nkvh
,
ngroup
,
seq_len
,
dh
});
rearrange
(
q_rearrange
,
q
);
rearrange
(
q_rearrange
,
q
);
auto
qk_gemm
=
qk_buf
->
viewReshaped
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
qk_gemm
=
qk_buf
->
viewReshaped
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
...
@@ -228,7 +230,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -228,7 +230,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
);
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
);
// rearrange attn val
// rearrange attn val
auto
attn_val_gemm
=
attn_val_buf
->
viewReshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
rearrange
(
o
,
attn_val_gemm
);
rearrange
(
o
,
attn_val_gemm
);
token_offset
+=
seq_len
;
token_offset
+=
seq_len
;
...
...
src/tensor.hpp
View file @
8853663e
...
@@ -130,7 +130,7 @@ public:
...
@@ -130,7 +130,7 @@ public:
std
::
shared_ptr
<
Tensor
>
view
()
const
;
std
::
shared_ptr
<
Tensor
>
view
()
const
;
std
::
shared_ptr
<
Tensor
>
view
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
const
;
std
::
shared_ptr
<
Tensor
>
view
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
const
;
std
::
shared_ptr
<
Tensor
>
viewReshaped
(
const
std
::
vector
<
size_t
>
new_shape
)
const
;
std
::
shared_ptr
<
Tensor
>
viewReshaped
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
;
~
Tensor
();
~
Tensor
();
};
};
...
...
src/tensor/tensor.cpp
View file @
8853663e
...
@@ -274,23 +274,69 @@ std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const
...
@@ -274,23 +274,69 @@ std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> new_shape, const
return
tensor
;
return
tensor
;
}
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
viewReshaped
(
const
std
::
vector
<
size_t
>
new_shape
)
const
{
std
::
shared_ptr
<
Tensor
>
Tensor
::
viewReshaped
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
{
// Create a copy of the current shape and strides
// Calculate total elements in current and new shape
auto
current_shape
=
_desc
->
shape
();
size_t
current_elements
=
std
::
accumulate
(
_desc
->
shape
().
begin
(),
_desc
->
shape
().
end
(),
// Start with the current tensor
1
,
std
::
multiplies
<
size_t
>
());
auto
result
=
this
->
view
();
size_t
new_elements
=
std
::
accumulate
(
new_shape
.
begin
(),
new_shape
.
end
(),
1
,
std
::
multiplies
<
size_t
>
());
ASSERT_EQ
(
current_elements
,
new_elements
);
const
auto
&
old_shape
=
_desc
->
shape
();
const
auto
&
old_strides
=
_desc
->
strides
();
// Special case: empty tensor
if
(
current_elements
==
0
)
{
auto
result
=
std
::
make_shared
<
Tensor
>
();
result
->
_storage
=
this
->
_storage
;
result
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
new_shape
,
{});
result
->
_offset
=
this
->
_offset
;
return
result
;
}
// Step 1: Merge all dimensions (if there are more than 1)
// Special case: scalar to scalar
if
(
current_shape
.
size
()
>
1
)
{
if
(
old_shape
.
empty
()
&&
new_shape
.
empty
())
{
result
=
result
->
dimMerge
(
0
,
current_shape
.
size
()
-
1
);
auto
result
=
std
::
make_shared
<
Tensor
>
();
result
->
_storage
=
this
->
_storage
;
result
->
_desc
=
this
->
_desc
;
result
->
_offset
=
this
->
_offset
;
return
result
;
}
}
// Step 2: Split into the new shape
// Compute new strides
if
(
new_shape
.
size
()
>
1
)
{
std
::
vector
<
ptrdiff_t
>
new_strides
;
result
=
result
->
dimSplit
(
0
,
new_shape
);
if
(
!
new_shape
.
empty
())
{
new_strides
.
resize
(
new_shape
.
size
());
// Compute strides for the new shape while preserving memory layout
// Start from the rightmost dimension
new_strides
.
back
()
=
old_strides
.
back
();
for
(
int
i
=
new_shape
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
new_strides
[
i
]
=
new_strides
[
i
+
1
]
*
new_shape
[
i
+
1
];
}
// Verify the new strides are compatible with the old memory layout
size_t
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
old_shape
.
size
();
++
i
)
{
offset
+=
(
old_shape
[
i
]
-
1
)
*
old_strides
[
i
];
}
size_t
new_offset
=
0
;
for
(
size_t
i
=
0
;
i
<
new_shape
.
size
();
++
i
)
{
new_offset
+=
(
new_shape
[
i
]
-
1
)
*
new_strides
[
i
];
}
ASSERT_EQ
(
offset
,
new_offset
);
}
}
// Create and return the reshaped tensor
auto
result
=
std
::
make_shared
<
Tensor
>
();
result
->
_storage
=
this
->
_storage
;
result
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
new_shape
,
new_strides
);
result
->
_offset
=
this
->
_offset
;
return
result
;
return
result
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment