OpenDAS / Lmdeploy / Commits

Commit 434961c6 (unverified)
Authored Nov 23, 2023 by Li Zhang; committed by GitHub on Nov 23, 2023

Fix cache/output length calculation (#738)

Parent: 6b00f623
Showing 1 changed file with 13 additions and 13 deletions.

src/turbomind/models/llama/LlamaBatch.cc (+13, -13)
@@ -207,7 +207,6 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         auto& seq = *state.sequences[idx];

         if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
-            /// TODO: revise step setting
             if (step <= seq.tokens.size()) {
                 seq.tokens.resize(step);
                 seq.cache_len = std::min(seq.cache_len, step);
@@ -1258,7 +1257,17 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
     check_cuda_error(cudaStreamSynchronize(stream_));

-    // invariant: context_length = sequence_length + 1, so that h_context_length include all (including the one just
-    // generated) tokens
+    // invariant: context_length = sequence_length + 1
+    // `SequenceManager` needs real-time value of cache length
+    // ! Must be done before incrementing `h_context_length` because the generated token is NOT kv-cached yet
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len = state_->h_context_length[i];
+        }
+    }
+
     for (int i = 0; i < batch_size; ++i) {
         ++state_->h_context_length[i];
     }
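The comments in this hunk spell out the ordering constraint: the token sampled in the current step is already counted toward the context length, but its KV entries are only written on the next forward pass, so the cache length has to be captured before the increment. A minimal standalone sketch of that arithmetic, using plain ints as stand-ins for lmdeploy's per-sequence state (not the real types or buffers):

// Standalone illustration of the ordering fixed in this hunk; the variables
// are simplified stand-ins, not lmdeploy's actual data structures.
#include <cassert>
#include <cstdio>

int main()
{
    int h_context_length = 4;  // tokens processed so far, all of them KV-cached
    int cache_len        = 0;  // the sequence manager's view of the cached length

    // A token was just sampled for this step; its KV entries will only be
    // written on the next forward pass. So first sync the cache length...
    cache_len = h_context_length;
    // ...and only then count the freshly generated token in the context length.
    ++h_context_length;

    // Invariant stated in the diff: context_length = sequence_length + 1.
    assert(h_context_length == cache_len + 1);
    std::printf("cache_len=%d, context_length=%d\n", cache_len, h_context_length);
    return 0;
}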
@@ -1267,7 +1276,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
         int* output_ptr = h_output_ids_;
         for (int i = 0; i < batch_size; ++i) {
             if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
-                const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
+                const int count = state_->h_context_length[i];
                 // TODO: sync history output tokens at when receiving the request and copy only the last token here
                 std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
                 *h_request_seqlen_ptrs_[i] = count;
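Since `h_context_length` is incremented right after the cache-length sync in the previous hunk, it already counts the just-generated token, so the number of ids to copy per request is simply that value and the old `- 1 + int(g.step != g.max_init_ctx_len)` correction goes away. A hedged sketch of the resulting copy, with invented buffers and lengths rather than lmdeploy's real `h_output_ids_` layout:

// Simplified sketch of the output copy; the ids and sizes here are made up
// for illustration and do not mirror lmdeploy's actual memory layout.
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> output_ids       = {10, 11, 12, 13, 14};  // 4 prompt ids + 1 generated id
    const int        h_context_length = static_cast<int>(output_ids.size());  // already includes the new token

    const int        count = h_context_length;  // formula after this change
    std::vector<int> request_output(count);
    std::copy(output_ids.begin(), output_ids.begin() + count, request_output.begin());

    std::printf("copied %d ids (prompt + generated)\n", count);
    return 0;
}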
@@ -1284,14 +1293,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }

-    // `SequenceManager` needs real-time value of cache length
-    for (int i = 0; i < batch_size; ++i) {
-        if (state_->requests[i]) {
-            FT_CHECK(state_->sequences[i]);
-            state_->sequences[i]->cache_len = state_->h_context_length[i];
-        }
-    }
-
     std::vector<Signal> signals;
     {
         NvtxScope _("stream_and_completion_signal");
@@ -1343,8 +1344,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
         FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
     }
     else {
-        // Account for the last generated token if not a stop request (which doesn't generate)
-        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
+        const int output_len = state_->h_context_length[index];
         auto& seq = *state_->sequences[index];

         // Update token IDs