OpenDAS / Lmdeploy — commit a54e3e09
Unverified commit a54e3e09, authored Sep 26, 2023 by akhoroshev; committed by GitHub on Sep 26, 2023.

fix race condition (#460)

Parent: 327deaee
Showing 1 changed file with 14 additions and 0 deletions.

src/turbomind/models/llama/LlamaBatch.cc: +14 −0
src/turbomind/models/llama/LlamaBatch.cc

```diff
@@ -30,6 +30,9 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
     auto invalidate = [](const char* type, std::shared_ptr<Request>& req, int ec) {
         TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
+        // We don't need a barrier there because
+        // this lambda is called only for new requests
+        // which are visible only for rank = 0 thread.
         req->signal.set_value(ec);
         req.reset();
     };
```
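For context on what this lambda releases: each `Request` carries a promise-like `signal` that the submitting thread blocks on, so fulfilling it with an error code is what wakes the caller, and `req.reset()` then drops the batch thread's reference. A minimal sketch of that handshake with `std::promise` (the `Request` shape and the error code below are simplified stand-ins, not turbomind's actual types):

```cpp
#include <cstdio>
#include <future>
#include <memory>
#include <thread>

// Simplified stand-in for turbomind's Request: the submitter keeps the
// future; the batch thread fulfills the promise ("signal") when done.
struct Request {
    long              id;
    std::promise<int> signal;
};

int main()
{
    auto req    = std::make_shared<Request>();
    req->id     = 42;
    auto result = req->signal.get_future();

    // Batch thread rejecting the request with an error code. As the commit
    // comment notes, no barrier is needed on this path because a new request
    // is visible to the rank 0 thread only.
    std::thread batch([req]() mutable {
        std::printf("Skipping invalid request for id %ld\n", req->id);
        req->signal.set_value(1);  // wake the submitter with error code 1
        req.reset();               // drop this thread's reference
    });

    std::printf("request finished with code %d\n", result.get());
    batch.join();
}
```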
```diff
@@ -139,6 +142,12 @@ void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request
             check_cuda_error(cudaMemsetAsync(sequence_length.getPtr<int>(), 0, sizeof(int), stream_));
             check_cuda_error(cudaStreamSynchronize(stream_));
         }
+
+        // When the signal is set threads from LlamaV2::forward can exit
+        // and free inputs/outputs tensors.
+        // Therefore we need to make sure that no threads from LlamaV2::internalThreadEntry
+        // are accessing the tensors.
+        llama_->shared_state_->barrier->wait();
         if (rank_ == 0) {
             r->signal.set_value(ec);
         }
```
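This hunk is the heart of the fix: setting the signal lets the thread blocked in `LlamaV2::forward` return and free the request's input/output tensors, while the other tensor-parallel ranks inside `LlamaV2::internalThreadEntry` may still be reading them. The added `barrier->wait()` makes every rank finish its tensor accesses before rank 0 fulfils the promise. A minimal sketch of this barrier-before-signal pattern, using C++20 `std::barrier` in place of turbomind's own `Barrier` class (everything except the pattern itself is illustrative):

```cpp
#include <barrier>
#include <future>
#include <thread>
#include <vector>

int main()
{
    constexpr int     world_size = 4;
    std::barrier      barrier(world_size);
    std::promise<int> signal;
    auto              done    = signal.get_future();
    auto*             tensors = new std::vector<float>(1024, 0.f);  // request's in/out buffers

    std::vector<std::thread> ranks;
    for (int rank = 0; rank < world_size; ++rank) {
        ranks.emplace_back([&, rank] {
            (*tensors)[rank] += 1.f;  // every rank touches the request's tensors
            // All ranks must be past their last tensor access before the
            // signal is set; without this wait, rank 0 could wake the owner
            // while another rank is still mid-access (the race being fixed).
            barrier.arrive_and_wait();
            if (rank == 0) {
                signal.set_value(0);  // only rank 0 fulfils the promise
            }
        });
    }

    done.get();      // the owner (think LlamaV2::forward) wakes up here...
    delete tensors;  // ...and frees the tensors, safe only thanks to the barrier
    for (auto& t : ranks) {
        t.join();
    }
}
```

Remove the `arrive_and_wait()` line and the use-after-free reappears: `delete tensors` can then run concurrently with another rank's access, which is exactly the race this commit closes.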
```diff
@@ -1112,6 +1121,11 @@ void LlamaBatch<T>::finishRequest(int index, bool force_end)
         llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
     }
+    // When the signal is set threads from LlamaV2::forward can exit
+    // and free inputs/outputs tensors.
+    // Therefore we need to make sure that no threads from LlamaV2::internalThreadEntry
+    // are accessing the tensors.
+    llama_->shared_state_->barrier->wait();
     if (rank_ == 0) {
         requests_[index]->signal.set_value(0);
     }
```
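`finishRequest` gets the identical treatment: all ranks synchronize on `shared_state_->barrier` before rank 0 fulfils the signal, here with success code 0. The diff does not show how `Barrier` is implemented; purely as an assumption for illustration, a pre-C++20 codebase would typically wrap `pthread_barrier_t` behind the same `wait()` surface:

```cpp
#include <pthread.h>

// Hypothetical stand-in for the Barrier behind shared_state_->barrier;
// turbomind's real class is not shown in this diff.
class Barrier {
public:
    explicit Barrier(unsigned count)
    {
        pthread_barrier_init(&bar_, nullptr, count);
    }

    Barrier(const Barrier&) = delete;  // pthread barriers are not copyable

    // Blocks until `count` threads have arrived, then releases them all.
    void wait()
    {
        pthread_barrier_wait(&bar_);
    }

    ~Barrier()
    {
        pthread_barrier_destroy(&bar_);
    }

private:
    pthread_barrier_t bar_;
};
```

In the commit, one such barrier instance is shared by all ranks through `shared_state_`, presumably constructed with the tensor-parallel world size as its count.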