jerrrrry / infinilm / Commits

Commit e641693d, authored Aug 08, 2025 by wooway777
Parent: 6998a8f1

issue/21 - improved view implementation

Showing 4 changed files with 139 additions and 59 deletions:

    src/models/inference_context.cpp    +13   -1
    src/models/jiuge/jiuge.cpp           +8   -8
    src/tensor.hpp                       +1   -0
    src/tensor/tensor.cpp              +117  -50
src/models/inference_context.cpp

@@ -192,6 +192,18 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
                               float alpha, float beta,
                               std::shared_ptr<Tensor> residual,
                               std::shared_ptr<Tensor> bias) {
+    bool residual_flag = residual != nullptr;
+    if (bias && !residual) {
+        int ndim_diff = c->ndim() - 1;
+        ASSERT_EQ(bias->ndim(), 1);
+        ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
+        std::vector<ptrdiff_t> strides(ndim_diff, 0);
+        strides.push_back(bias->strides()[0]);
+        rearrange(c, bias->view_as(c->shape(), strides));
+        residual = c;
+    }
     if (residual) {
         if (residual->data() == c->data()) {
             if (beta == 0.0) {
...
@@ -210,7 +222,7 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
         gemm(c, a, b, alpha, beta);
     }
-    if (bias) {
+    if (bias && residual_flag) {
         int ndim_diff = c->ndim() - 1;
         ASSERT_EQ(bias->ndim(), 1);
         ASSERT_EQ(bias->shape()[0], c->shape()[ndim_diff]);
...
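Note on the new bias path: when a bias is given but no residual, linear() now materializes the bias into c before the GEMM by viewing the 1-D bias with zeroed leading strides and copying. Below is a standalone sketch (my own code, not part of the commit) of why a stride-0 view broadcasts: the row index contributes nothing to the offset, so a single rearrange-style copy writes the bias into every row.

    // Sketch: a 1-D bias "viewed" as {rows, cols} with strides {0, 1}.
    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t rows = 3, cols = 4;
        std::vector<float> bias = {0.1f, 0.2f, 0.3f, 0.4f};
        // Row index times stride 0 contributes nothing to the offset, so
        // every row of the view aliases the same underlying bias elements.
        const ptrdiff_t strides[2] = {0, 1};
        std::vector<float> c(rows * cols);
        for (size_t i = 0; i < rows; ++i)
            for (size_t j = 0; j < cols; ++j)
                c[i * cols + j] = bias[i * strides[0] + j * strides[1]];
        assert(c[2 * cols + 1] == bias[1]); // each output row equals the bias
        printf("ok\n");
        return 0;
    }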
src/models/jiuge/jiuge.cpp

@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
     auto result_cpu = std::vector<int64_t>(nreq);
-    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto qkv_rope = qkv_buf->view_as({ntok, nh + nkvh * 2, dh});
     // Prepare inputs
     auto batch_pos_ids = std::vector<uint32_t>(ntok);
...
@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
     auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
-    auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
+    auto q_rearrange = rearrange_q_buf->view_as({nkvh, ngroup, max_seq_len, dh});
     auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
-    auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh});
+    auto attn_val_gemm = attn_val_buf->view_as({nkvh, ngroup, max_seq_len, dh});
     // MLP buffers
     auto gate_buf = gate_up_buf->slice(1, 0, di);
...
@@ -207,8 +207,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         auto past_len = req_pos[req];
         auto seq_len = req_lens[req];
         auto total_len = past_len + seq_len;
-        auto o = o_buf->view({ntok, nh, dh})->slice({{0, token_offset, seq_len}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
-        auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
+        auto o = o_buf->slice({{0, token_offset, seq_len}})->view_as({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
+        auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
         auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
         auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
...
@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
         // qk
         rearrange(q_rearrange, q);
-        auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
+        auto qk_gemm = qk_buf->view_as({nkvh, ngroup * seq_len, total_len});
         auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
         linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
         // softmax
-        auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
+        auto qk_softmax = qk_buf->view_as({nh, seq_len, total_len});
         causalSoftmax(qk_softmax, qk_softmax);
         auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
         linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
...
@@ -322,7 +322,7 @@ inferBatch(struct JiugeModel *model,
 void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc,
                   InferState &state, InferRequest &req, infiniDevice_t device,
                   int idev, int ndev, int dev_id, infinicclComm_t comm) {
-    CacheManager cache_manager(256);
+    CacheManager cache_manager(100);
     InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
     // Set the inference context for this thread
...
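The view to view_as switches above are all on buffers freshly allocated with Tensor::buffer, which are contiguous, so a plain contiguous reinterpretation suffices and the stride-matching checks in the reworked view (see src/tensor/tensor.cpp below) can be skipped; the new q keeps view because it is built from a slice and needs the stride-aware path. A sketch of the row-major strides a shape-only view_as amounts to assuming (helper name is mine, not project code):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    std::vector<ptrdiff_t> contiguous_strides(const std::vector<size_t> &shape) {
        std::vector<ptrdiff_t> strides(shape.size());
        ptrdiff_t s = 1;
        for (size_t i = shape.size(); i-- > 0;) {
            strides[i] = s;                        // innermost dimension moves by 1
            s *= static_cast<ptrdiff_t>(shape[i]); // accumulate outward
        }
        return strides;
    }

    int main() {
        // e.g. a {nkvh, ngroup, max_seq_len, dh} buffer with 4, 2, 8, 64:
        auto s = contiguous_strides({4, 2, 8, 64});
        printf("{%td, %td, %td, %td}\n", s[0], s[1], s[2], s[3]); // {1024, 512, 64, 1}
        return 0;
    }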
src/tensor.hpp

@@ -130,6 +130,7 @@ public:
     std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
     std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape,
                                     const std::vector<ptrdiff_t> &new_strides) const;
+    std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const;
     ~Tensor();
 };
...
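For reference, the shape-only overload declared above reuses the storage and offset and attaches a fresh descriptor for the new shape; here is a minimal stand-alone sketch with stand-in types (the real body is in the tensor.cpp hunk below):

    #include <cstdio>
    #include <memory>
    #include <vector>

    struct Desc { std::vector<size_t> shape; };      // stand-in for TensorDesc

    struct T {                                       // stand-in for Tensor
        std::shared_ptr<std::vector<char>> storage;  // shared underlying buffer
        std::shared_ptr<Desc> desc;
        size_t offset = 0;

        std::shared_ptr<T> view_as(const std::vector<size_t> &new_shape) const {
            auto t = std::make_shared<T>();
            t->storage = storage;                    // no copy: alias the storage
            t->desc = std::make_shared<Desc>(Desc{new_shape});
            t->offset = offset;                      // same starting element
            return t;
        }
    };

    int main() {
        T a;
        a.storage = std::make_shared<std::vector<char>>(24);
        a.desc = std::make_shared<Desc>(Desc{{6, 4}});
        auto b = a.view_as({3, 2, 4});
        printf("same storage: %d\n", b->storage.get() == a.storage.get()); // 1
        return 0;
    }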
src/tensor/tensor.cpp

@@ -259,69 +259,128 @@ std::string Tensor::info() const {
 }

 std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
-    // Calculate total elements in current and new shape
-    size_t current_elements = std::accumulate(_desc->shape().begin(), _desc->shape().end(),
-                                              1, std::multiplies<size_t>());
-    size_t new_elements = std::accumulate(new_shape.begin(), new_shape.end(),
-                                          1, std::multiplies<size_t>());
-    ASSERT_EQ(current_elements, new_elements);
-    const auto &old_shape = _desc->shape();
-    const auto &old_strides = _desc->strides();
-    // Special case: empty tensor
-    if (current_elements == 0) {
-        auto result = std::make_shared<Tensor>();
-        result->_storage = this->_storage;
-        result->_desc = TensorDesc::create(this->dtype(), new_shape, {});
-        result->_offset = this->_offset;
-        return result;
-    }
-    // Special case: scalar to scalar
-    if (old_shape.empty() && new_shape.empty()) {
-        auto result = std::make_shared<Tensor>();
-        result->_storage = this->_storage;
-        result->_desc = this->_desc;
-        result->_offset = this->_offset;
-        return result;
-    }
-    // Compute new strides
-    std::vector<ptrdiff_t> new_strides;
-    if (!new_shape.empty()) {
-        new_strides.resize(new_shape.size());
-        // Compute strides for the new shape while preserving memory layout
-        // Start from the rightmost dimension
-        new_strides.back() = old_strides.back();
-        for (int i = new_shape.size() - 2; i >= 0; --i) {
-            new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
-        }
-    }
-    // Verify the new strides are compatible with the old memory layout
-    size_t offset = 0;
-    for (size_t i = 0; i < old_shape.size(); ++i) {
-        offset += (old_shape[i] - 1) * old_strides[i];
-    }
-    size_t new_offset = 0;
-    for (size_t i = 0; i < new_shape.size(); ++i) {
-        new_offset += (new_shape[i] - 1) * new_strides[i];
-    }
-    ASSERT_EQ(offset, new_offset);
-    // Create and return the reshaped tensor
-    auto result = std::make_shared<Tensor>();
-    result->_storage = this->_storage;
-    result->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
-    result->_offset = this->_offset;
-    return result;
+    // Calculate total number of elements
+    size_t numel = 1;
+    for (auto s : shape()) {
+        numel *= s;
+    }
+    size_t new_numel = 1;
+    for (auto s : new_shape) {
+        new_numel *= s;
+    }
+    ASSERT(numel == new_numel);
+    // Handle empty tensors
+    if (numel == 0) {
+        return this->view_as(new_shape, {});
+    }
+    // Special case: view(-1) flattens the tensor
+    if (new_shape.size() == 1 && new_shape[0] == static_cast<size_t>(-1)) {
+        std::vector<size_t> flat_shape = {numel};
+        return this->view_as(flat_shape, {});
+    }
+    // Check for -1 in new_shape (infer dimension)
+    std::vector<size_t> inferred_shape = new_shape;
+    size_t infer_index = static_cast<size_t>(-1);
+    size_t known_elements = 1;
+    for (size_t i = 0; i < new_shape.size(); ++i) {
+        if (new_shape[i] == static_cast<size_t>(-1)) {
+            ASSERT(infer_index == static_cast<size_t>(-1)); // Only one -1 allowed
+            infer_index = i;
+        } else {
+            known_elements *= new_shape[i];
+        }
+    }
+    if (infer_index != static_cast<size_t>(-1)) {
+        ASSERT(numel % known_elements == 0);
+        inferred_shape[infer_index] = numel / known_elements;
+    }
+    // For contiguous tensors, compute standard row-major strides
+    if (this->isContigous()) {
+        std::vector<ptrdiff_t> new_strides(inferred_shape.size());
+        if (!inferred_shape.empty()) {
+            new_strides.back() = 1;
+            for (int i = static_cast<int>(inferred_shape.size()) - 2; i >= 0; --i) {
+                new_strides[i] = new_strides[i + 1] * static_cast<ptrdiff_t>(inferred_shape[i + 1]);
+            }
+        }
+        return this->view_as(inferred_shape, new_strides);
+    }
+    // For non-contiguous tensors
+    std::vector<size_t> old_shape = shape();
+    std::vector<ptrdiff_t> old_strides = strides();
+    std::vector<ptrdiff_t> new_strides(inferred_shape.size(), 0);
+    size_t old_idx = old_shape.size() - 1;
+    size_t new_idx = inferred_shape.size() - 1;
+    if (new_idx != static_cast<size_t>(-1)) {
+        new_strides[new_idx] = 1;
+    }
+    while (old_idx != static_cast<size_t>(-1) && new_idx != static_cast<size_t>(-1)) {
+        size_t old_size = old_shape[old_idx];
+        size_t new_size = inferred_shape[new_idx];
+        if (old_size == 1) {
+            old_idx--;
+        } else if (new_size == 1) {
+            new_strides[new_idx] = (new_idx == inferred_shape.size() - 1) ? 1 : new_strides[new_idx + 1];
+            new_idx--;
+        } else if (old_size == new_size) {
+            new_strides[new_idx] = old_strides[old_idx];
+            old_idx--;
+            new_idx--;
+        } else if (old_size < new_size) {
+            size_t combined_size = old_size;
+            ptrdiff_t combined_stride = old_strides[old_idx];
+            old_idx--;
+            while (old_idx != static_cast<size_t>(-1) && combined_size < new_size) {
+                ASSERT(static_cast<size_t>(old_strides[old_idx])
+                       == old_shape[old_idx + 1] * static_cast<size_t>(old_strides[old_idx + 1]));
+                combined_size *= old_shape[old_idx];
+                combined_stride = old_strides[old_idx];
+                old_idx--;
+            }
+            ASSERT(combined_size == new_size);
+            new_strides[new_idx] = combined_stride;
+            new_idx--;
+        } else {
+            size_t remaining_size = old_size / new_size;
+            ASSERT(old_size % new_size == 0);
+            new_strides[new_idx] = old_strides[old_idx] * static_cast<ptrdiff_t>(remaining_size);
+            new_idx--;
+            if (remaining_size != 1) {
+                if (new_idx != static_cast<size_t>(-1)) {
+                    inferred_shape[new_idx] = remaining_size;
+                    new_strides[new_idx] = old_strides[old_idx];
+                    new_idx--;
+                } else {
+                    ASSERT(false);
+                }
+            }
+            old_idx--;
+        }
+    }
+    // Fill remaining dimensions (must be size 1)
+    while (new_idx != static_cast<size_t>(-1)) {
+        ASSERT(inferred_shape[new_idx] == 1);
+        new_strides[new_idx] = new_strides[new_idx + 1];
+        new_idx--;
+    }
+    return this->view_as(inferred_shape, new_strides);
 }

 std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape,
                                         const std::vector<ptrdiff_t> &new_strides) const {
...
@@ -332,6 +391,14 @@ std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const
     return tensor;
 }

+std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const {
+    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
+    tensor->_storage = this->_storage;
+    tensor->_desc = TensorDesc::create(this->dtype(), new_shape);
+    tensor->_offset = this->_offset;
+    return tensor;
+}
+
 void Tensor::debug(const std::string &filename) const {
     RUN_INFINI(infinirtDeviceSynchronize());
...
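The core of the reworked view is the right-to-left walk over both shapes that matches element counts, merging adjacent old dimensions into one new dimension only when they are contiguous in memory, plus inference of a single -1 dimension. The following is a simplified standalone sketch of that stride-matching idea (my own code, with splitting omitted; it is not the project's exact implementation):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    std::vector<ptrdiff_t> view_strides(std::vector<size_t> old_shape,
                                        std::vector<ptrdiff_t> old_strides,
                                        std::vector<size_t> new_shape) {
        // Resolve a single SIZE_MAX ("-1") entry from the total element count.
        size_t numel = 1;
        for (size_t s : old_shape) numel *= s;
        size_t known = 1, infer = SIZE_MAX;
        for (size_t i = 0; i < new_shape.size(); ++i) {
            if (new_shape[i] == SIZE_MAX) { assert(infer == SIZE_MAX); infer = i; }
            else known *= new_shape[i];
        }
        if (infer != SIZE_MAX) { assert(numel % known == 0); new_shape[infer] = numel / known; }

        // Walk both shapes right to left, matching element counts; a run of
        // old dims may only be fused if it is contiguous in memory.
        std::vector<ptrdiff_t> ns(new_shape.size(), 0);
        size_t oi = old_shape.size(), ni = new_shape.size();
        while (ni > 0) {
            size_t want = new_shape[--ni];
            if (want == 1) { ns[ni] = (ni + 1 < ns.size()) ? ns[ni + 1] : 1; continue; }
            assert(oi > 0);
            size_t have = old_shape[--oi];
            ptrdiff_t stride = old_strides[oi];
            while (have < want) { // merge adjacent old dims into one new dim
                assert(oi > 0 && old_strides[oi - 1] ==
                       static_cast<ptrdiff_t>(old_shape[oi]) * old_strides[oi]);
                have *= old_shape[--oi];
            }
            assert(have == want); // splitting is omitted in this sketch
            ns[ni] = stride;
        }
        return ns;
    }

    int main() {
        // A 2x3x4 row-major tensor viewed as {6, -1}: dims 0 and 1 merge.
        auto s = view_strides({2, 3, 4}, {12, 4, 1}, {6, SIZE_MAX});
        printf("strides = {%td, %td}\n", s[0], s[1]); // {4, 1}
        return 0;
    }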