Unverified commit 911c0a85, authored by Li Zhang, committed by GitHub

Optimize for throughput (#701)



* tmp

* update

* update

* optimize for throughput

* update

* fix eos

* clean up

* fix serving

* fix indexed copy

* minor

* minor

---------
Co-authored-by: lvhan028 <lvhan_028@163.com>
parent 65d735ba
@@ -448,8 +448,8 @@ int main(int argc, char* argv[])
     std::vector<int> hBuf(outCount);
-    ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
-    ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
+    ft::cudaAutoCpy(hBuf.data(), d_output_ids, outCount);
+    ft::cudaAutoCpy(seq_lens.data(), d_seq_lens, batch_size);
     std::cout << "sequence length: ";
     for (int i = 0; i < batch_size; ++i) {
...
@@ -350,7 +350,7 @@ class TurboMindInstance:
             outputs = _tm_dict_to_torch_dict(tm_outputs)
             output_ids = outputs['output_ids'][:, 0, :]
-            sequence_length = outputs['sequence_length'].long()[:, 0].cpu()
+            sequence_length = outputs['sequence_length'].long()[:, 0]
             output_ids = [
                 output_id[s:l] for output_id, s, l in zip(
                     output_ids, seq_start, sequence_length)
@@ -366,7 +366,6 @@ class TurboMindInstance:
                     outputs.append((output[:-1], len_))
                 else:
                     outputs.append((output, len_))
                 yield outputs
                 if finish:
...
@@ -236,15 +236,88 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
 template<typename T, int N>
 struct Array {
-    T a[N];
-
-    __device__ __host__ constexpr T& operator[](int i) noexcept
-    {
-        return a[i];
-    }
-
-    __device__ __host__ constexpr const T& operator[](int i) const noexcept
-    {
-        return a[i];
-    }
+    using value_type      = T;
+    using size_type       = int;
+    using difference_type = int;
+    using reference       = value_type&;
+    using const_reference = const value_type&;
+    using pointer         = value_type*;
+    using const_pointer   = const value_type*;
+    using iterator        = pointer;
+    using const_iterator  = const_pointer;
+
+    static_assert(N > 0);
+
+    T __a[N];
+
+    __device__ __host__ constexpr reference operator[](size_type i) noexcept
+    {
+        return __a[i];
+    }
+
+    __device__ __host__ constexpr const_reference operator[](size_type i) const noexcept
+    {
+        return __a[i];
+    }
+
+    __device__ __host__ constexpr reference front() noexcept
+    {
+        return *begin();
+    }
+
+    __device__ __host__ constexpr const_reference front() const noexcept
+    {
+        return *begin();
+    }
+
+    __device__ __host__ constexpr reference back() noexcept
+    {
+        return *(end() - 1);
+    }
+
+    __device__ __host__ constexpr const_reference back() const noexcept
+    {
+        return *(end() - 1);
+    }
+
+    __device__ __host__ constexpr pointer data() noexcept
+    {
+        return &__a[0];
+    }
+
+    __device__ __host__ constexpr const_pointer data() const noexcept
+    {
+        return &__a[0];
+    }
+
+    __device__ __host__ constexpr iterator begin() noexcept
+    {
+        return data();
+    }
+
+    __device__ __host__ constexpr const_iterator begin() const noexcept
+    {
+        return data();
+    }
+
+    __device__ __host__ constexpr iterator end() noexcept
+    {
+        return data() + N;
+    }
+
+    __device__ __host__ constexpr const_iterator end() const noexcept
+    {
+        return data() + N;
+    }
+
+    __device__ __host__ constexpr std::integral_constant<int, N> size() const noexcept
+    {
+        return {};
+    }
+
+    __device__ __host__ constexpr std::false_type empty() const noexcept
+    {
+        return {};
+    }
 };
...
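Note: the rewritten `Array<T, N>` is a minimal `std::array`-like container usable from both host and device code. A small sketch of how such an interface can be consumed (the `sum` helper below is illustrative only, not part of this commit):

    // Hypothetical device helper built on the Array<T, N> interface above.
    template<typename T, int N>
    __device__ T sum(const Array<T, N>& a)
    {
        T acc{};
        for (const auto& v : a) {  // begin()/end() enable range-based for loops
            acc += v;
        }
        return acc;
    }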
@@ -188,6 +188,7 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
  *
  * output_tensors:
  *   \param output_ids [max_seq_len, batch_size]
+ *   \param curand_state [local_batch_size]
  *   \param finished [batch_size * beam_width], optional
  *   \param should_stop [1] on cpu
  *   \param cum_log_probs [batch_size * beam_width], necessary in beam search
@@ -276,7 +277,8 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
                 {"input_lengths", input_lengths.slice({local_batch_size, beam_width}, local_batch_offset)});
         }
-        TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}});
+        TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")},  //
+                                         {"curand_state", output_tensors->at("curand_state")}});
         if (output_tensors->isExist("sequence_length")) {
             Tensor sequence_length = output_tensors->at("sequence_length");
             decode_output_tensors.insert(
...
@@ -53,15 +53,6 @@ protected:
     int* h_pinned_finished_sum_ = nullptr;

 public:
-    curandState_t* topk_curandstate_buf()
-    {
-        return static_cast<BaseSamplingLayer<T>*>(topk_decode_)->curandstate_buf();
-    }
-    curandState_t* topp_curandstate_buf()
-    {
-        return static_cast<BaseSamplingLayer<T>*>(topp_decode_)->curandstate_buf();
-    }
-
     DynamicDecodeLayer(size_t vocab_size,
                        size_t vocab_size_padded,
                        int    end_id,
...
@@ -30,10 +30,6 @@ template<typename T>
 void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    curandstate_buf_ = reinterpret_cast<curandState_t*>(
-        allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false));
-    random_seeds_buf_ = reinterpret_cast<unsigned long long*>(
-        allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false));
     temperature_buf_ =
         reinterpret_cast<float*>(allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false));
     repetition_penalty_buf_ =
@@ -58,8 +54,6 @@ void BaseSamplingLayer<T>::freeBuffer()
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     if (is_allocate_buffer_) {
-        allocator_->free((void**)(&curandstate_buf_));
-        allocator_->free((void**)(&random_seeds_buf_));
         allocator_->free((void**)(&temperature_buf_));
         allocator_->free((void**)(&repetition_penalty_buf_));
         allocator_->free((void**)(&min_lengths_buf_));
@@ -128,32 +122,6 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_width
     Tensor runtime_top_p = runtime_args->isExist("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor();
     allocateBuffer(batch_size, runtime_top_k, runtime_top_p);

-    // If runtime argument has single random seed, using this random seed to initialize the random table of all
-    // sentences. If the argument has [batch_size] random seeds, initializing the random table by different random seeds
-    // respectively. If no random seed, initialize the random table of all sentences by 0 directly.
-    if (runtime_args->isExist("random_seed")) {
-        Tensor random_seeds = runtime_args->at("random_seed");
-        FT_CHECK_WITH_INFO(random_seeds.shape.size() == 1
-                               && (random_seeds.size() == 1 || random_seeds.size() == batch_size),
-                           fmtstr("random_seeds must be of shape [1] or [batch_size(%ld)], got random_seeds.shape=%s",
-                                  batch_size,
-                                  vec2str(random_seeds.shape).c_str()));
-        if (random_seeds.size() == 1) {
-            invokeCurandInitialize(curandstate_buf_, batch_size, random_seeds.getVal<unsigned long long>(), stream_);
-            sync_check_cuda_error();
-        }
-        else {
-            unsigned long long* random_seed_ptr = random_seeds.getPtr<unsigned long long>();
-            cudaAutoCpy(random_seeds_buf_, random_seed_ptr, batch_size, stream_);
-            invokeCurandBatchInitialize(curandstate_buf_, batch_size, random_seeds_buf_, stream_);
-            sync_check_cuda_error();
-        }
-    }
-    else {
-        // Initialize curand states using the default seed 0.
-        invokeCurandInitialize(curandstate_buf_, batch_size, 0, stream_);
-    }
     // Setup penalties.
     const float default_temperature = 1.0f;
     Tensor      temperature         = runtime_args->isExist("temperature") ?
...
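Note: with the per-layer curand buffers removed, random-state initialization moves up to the batch level and is seeded once per sequence. As a rough sketch of what batched initialization with per-sequence seeds amounts to (kernel name and launch configuration are illustrative, not the project's `invokeCurandBatchInitialize`):

    #include <curand_kernel.h>

    // Illustrative only: one generator state per sequence, each seeded independently.
    __global__ void curand_batch_init_sketch(curandState_t* states, const unsigned long long* seeds, int n)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            curand_init(seeds[i], /*subsequence=*/0, /*offset=*/0, &states[i]);
        }
    }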
@@ -33,10 +33,8 @@ protected:
     size_t vocab_size_;
     size_t vocab_size_padded_;
     size_t sampling_workspace_size_;
     void*  sampling_workspace_ = nullptr;
-    curandState_t*      curandstate_buf_  = nullptr;
-    unsigned long long* random_seeds_buf_ = nullptr;
     float* temperature_buf_        = nullptr;
     float* repetition_penalty_buf_ = nullptr;
@@ -59,11 +57,6 @@ protected:
     virtual void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p);

 public:
-    curandState_t* curandstate_buf()
-    {
-        return curandstate_buf_;
-    }
-
     BaseSamplingLayer(size_t max_batch_size,
                       size_t vocab_size,
                       size_t vocab_size_padded,
...
@@ -16,6 +16,7 @@
  */
 #include <float.h>
+#include <sstream>

 #include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/kernels/sampling_topp_kernels.h"
@@ -199,6 +200,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* input_
     // output_tensors:
     //      output_ids [max_seq_len, batch_size]
+    //      curand_state [local_batch_size]
     //      finished [local_batch_size], optional
     //      sequence_length [local_batch_size], optional
     //      cum_log_probs [batch_size], must be float*, optional
@@ -255,7 +257,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* input_
         output_tensors->at("finished", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<bool>(),
         cum_log_probs,
         output_log_probs,
-        curandstate_buf_ + ite * local_batch_size,
+        output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
         (int)runtime_max_top_k_,  // useless because runtime_top_k_buf_ is never nullptr. Keep for legacy.
         (int*)(runtime_top_k_buf_ + ite * local_batch_size),
         1.0f,  // useless because runtime_top_p_buf_ is never nullptr. Keep for legacy.
...
@@ -40,8 +40,6 @@ private:
     using BaseSamplingLayer<T>::sampling_workspace_size_;
     using BaseSamplingLayer<T>::sampling_workspace_;
-    using BaseSamplingLayer<T>::curandstate_buf_;
-    using BaseSamplingLayer<T>::random_seeds_buf_;
     using BaseSamplingLayer<T>::skip_decode_buf_;
     using BaseSamplingLayer<T>::skip_decode_;
     using BaseSamplingLayer<T>::skip_any_;
...
@@ -132,7 +132,7 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p)
                          topp_id_vals_buf_,
                          topp_offset_buf_,
                          begin_topp_offset_buf_,
-                         curandstate_buf_,
+                         nullptr,  // not used when workspace is null
                          batch_size,
                          vocab_size_padded_,
                          nullptr,
@@ -267,6 +267,7 @@ void TopPSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* input_
  * output_tensors:
  *   \param output_ids [max_seq_len, batch_size]
+ *   \param curand_state [local_batch_size]
  *   \param finished [local_batch_size], optional
  *   \param sequence_length [local_batch_size], optional
  *   \param cum_log_probs [batch_size], must be float*, optional
@@ -319,7 +320,7 @@ void TopPSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* input_
                          topp_id_vals_buf_,
                          topp_offset_buf_,
                          begin_topp_offset_buf_,
-                         curandstate_buf_ + ite * local_batch_size,
+                         output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
                          local_batch_size,
                          vocab_size_padded_,
                          input_tensors->at("end_id").getPtr<int>(),
...
@@ -48,8 +48,6 @@ private:
     using BaseSamplingLayer<T>::sampling_workspace_size_;
     using BaseSamplingLayer<T>::sampling_workspace_;
-    using BaseSamplingLayer<T>::curandstate_buf_;
-    using BaseSamplingLayer<T>::random_seeds_buf_;
     using BaseSamplingLayer<T>::skip_decode_buf_;
     using BaseSamplingLayer<T>::skip_decode_;
     using BaseSamplingLayer<T>::skip_any_;
...
@@ -2,6 +2,7 @@
 #include "src/turbomind/models/llama/LlamaBatch.h"
 #include "src/turbomind/kernels/decoding_kernels.h"
+#include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/LlamaV2.h"
@@ -16,12 +17,15 @@
 #include "src/turbomind/utils/logger.h"
 #include <algorithm>
 #include <cmath>
+#include <cstddef>
 #include <cstdint>
 #include <iomanip>
+#include <iterator>
 #include <mutex>
 #include <numeric>
 #include <sstream>
 #include <unordered_map>
+#include <utility>

 namespace turbomind {
@@ -142,7 +146,9 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
 template<typename T>
 auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
 {
+    NvtxScope           scope("stop_request");
     std::vector<Signal> signals;
+    int                 count = 0;
     for (const auto& r : requests) {
         int ec = Request::kFail;
         // find matching active sequence
@@ -150,29 +156,25 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
             // stop & optionally erase active sequence
             if (state_->requests[i] && state_->requests[i]->id == r->id) {
                 ec = 0;
-                CompleteRequest(i, true, r->end_flag);
-                state_->requests[i].reset();
+                signals.push_back(Interrupt(i, true, r->end_flag));
+                ++count;
                 break;
             }
         }
-        // mismatch, try erase inactive sequence, in this case there is no active request to finish
+        // mismatch, try erase inactive sequence, in this case there is no active request to interrupt
         if (ec && r->end_flag) {
-            ec = 0;
-            sequence_manager_->Erase(r->id);
-        }
-        // clear output buffers (prevent leaking conversations) if request is successful
-        if (ec == 0) {
-            if (rank_ == 0) {
-                std::unique_lock lock{output_mutex_};
-                output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
-            }
-            auto& output_ids      = r->outputs[rank_].at("output_ids");
-            auto& sequence_length = r->outputs[rank_].at("sequence_length");
-            Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
-            Clear(sequence_length.getPtr<int>(), 1);
-            check_cuda_error(cudaStreamSynchronize(stream_));
-        }
-        signals.push_back([=] { r->signal.set_value(ec); });
+            if (sequence_manager_->Erase(r->id)) {
+                ec = 0;
+            }
+        }
+        signals.push_back([=] {
+            if (rank_ == 0) {
+                r->signal.set_value(ec);
+            }
+        });
+    }
+    if (count) {
+        check_cuda_error(cudaStreamSynchronize(stream_));
     }
     return signals;
 }
@@ -180,25 +182,29 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
 template<typename T>
 void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
 {
-    auto& state = *incoming_;
+    NvtxScope scope("infer_request");
+    auto&     state = *incoming_;

     FT_CHECK(state.size == 0);
     FT_CHECK(state.active_size == 0);

-    int i = 0;
+    std::vector<int> existing_idx;
+
+    int idx = 0;
     for (const auto& r : requests) {
-        // sanity check, incoming request in previous iter should have been moved to `state_`
-        FT_CHECK(!state.requests[i]);
+        FT_CHECK(!state.requests[idx]);

-        TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
+        if (rank_ == 0) {
+            TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
+        }

-        state.requests[i] = r;
+        state.requests[idx] = r;

         // get sequence for the request
-        state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
+        state.sequences[idx] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
+        FT_CHECK(state.sequences[idx]);

-        auto& seq = *state.sequences[i];
+        auto& seq = *state.sequences[idx];

         if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
             /// TODO: revise step setting
@@ -216,7 +222,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         const int* input_ids = r->inputs[rank_].getPtr<int>("input_ids");

         // `output_ids` contains all token ids of the sequences
-        const auto output_ids_base = state.output_ids + session_len_ * i;
+        const auto output_ids_base = state.output_ids + session_len_ * idx;
         auto       output_ids      = output_ids_base;

         // copy history tokens
@@ -230,21 +236,21 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
         }

         // total context length (history + input)
-        state.h_context_length[i] = output_ids - output_ids_base;
-        state.h_finished[i]       = false;
+        state.h_context_length[idx] = output_ids - output_ids_base;
+        state.h_finished[idx]       = false;

-        const int request_output_len = state.requests[i]->inputs[rank_].getVal<int>("request_output_len");
-        state.seq_len_limit[i]       = state.h_context_length[i] + request_output_len;
+        const int request_output_len = state.requests[idx]->inputs[rank_].getVal<int>("request_output_len");
+        state.seq_len_limit[idx]     = state.h_context_length[idx] + request_output_len;
         // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
         // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
-        if (state.seq_len_limit[i] >= session_len_) {
-            state.seq_len_limit[i] = session_len_ - 1;
+        if (state.seq_len_limit[idx] >= session_len_) {
+            state.seq_len_limit[idx] = session_len_ - 1;
             if (rank_ == 0) {
-                const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i];
+                const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
                 TM_LOG_WARNING(
                     "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
                     (long)seq.id,
-                    state.h_context_length[i],
+                    state.h_context_length[idx],
                     request_output_len,
                     (int)session_len_,
                     trunc_output_len);
@@ -260,7 +266,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
             }
             else if (model_->attn_params_.rope_scaling_factor >= 1.f) {  // infer by `seq_len_limit`
                 scaling_factor   = model_->attn_params_.rope_scaling_factor;
-                auto max_seq_len = state.seq_len_limit[i];
+                auto max_seq_len = state.seq_len_limit[idx];
                 auto max_pos_emb = model_->attn_params_.max_position_embeddings;
                 if (max_seq_len > max_pos_emb) {
                     scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
@@ -277,22 +283,45 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
                                seq.rope_theta);
                 }
             }
-        state.h_rope_theta[i] = seq.rope_theta;
+        state.h_rope_theta[idx] = seq.rope_theta;

-        // recover device states if not a new sequence
-        if (!r->start_flag) {
-            Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state);
-            Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state);
+        if (r->start_flag) {
+            // prepare to initialize random state for new sequence
+            h_random_seed_[idx] = r->inputs[rank_].getVal<unsigned long long>("random_seed", 0);
+        }
+        else {
+            // Recover device states if not a new sequence
+            h_curand_state_[existing_idx.size()] = *(curandState_t*)seq.random_state.data();
+            existing_idx.push_back(idx);
         }

-        // ! SHARED STATE IS MODIFIED, BARRIER SYNCHRONIZATION REQUIRED
         // assign priority based on arrival time
-        r->priority = request_count_++;
+        if (rank_ == 0) {
+            r->priority = request_count_++;
+        }

         // increment pointer
-        i++;
+        idx++;
     }

-    incoming_->size = i;
+    state.size = idx;
+
+    // when there are new sequences
+    if (state.size != existing_idx.size()) {
+        // copy random seeds to device
+        Copy(h_random_seed_, state.size, d_random_seed_);
+        // initialize random states
+        invokeCurandBatchInitialize(state.curand_state, state.size, d_random_seed_, stream_);
+        sync_check_cuda_error();
+    }
+
+    if (!existing_idx.empty()) {
+        // copy existing curand states to device
+        Copy(h_curand_state_, existing_idx.size(), d_curand_state_);
+        // insert the states to their correct positions in the batch
+        IndexedCopy({}, existing_idx, std::tuple{d_curand_state_, state.curand_state, 1});
+    }
 }
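Note: `IndexedCopy` is used above to scatter the recovered curand states of swapped-in sequences into their batch slots. A rough host-side sketch of the assumed scatter semantics (the real helper launches a kernel on the stream; `indexed_copy_sketch` is hypothetical):

    #include <algorithm>
    #include <vector>

    // Copies src[i] into dst[dst_idx[i]], with `stride` elements per slot; an empty
    // source index list (as in the call above) means source slots are taken in order.
    template<typename T>
    void indexed_copy_sketch(const T* src, T* dst, const std::vector<int>& dst_idx, int stride)
    {
        for (size_t i = 0; i < dst_idx.size(); ++i) {
            std::copy_n(src + i * stride, stride, dst + dst_idx[i] * stride);
        }
    }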
 template<typename T>
@@ -375,14 +404,11 @@ bool LlamaBatch<T>::Initialize()
     // Copy sequence states to back buffer
     FT_CHECK(back_->size == 0 && back_->active_size == 0);
+    std::vector<std::tuple<BatchState*, BatchState*, int, int>> cpys;
     for (const auto& i : idxs) {
         auto& s = *sequences[i];
         if (exchange) {
             const auto& [state, idx] = coords[i];
-            // backup random states from dynamic decode layers for swap-outs
-            if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
-                SaveRandomState(*state, idx);
-            }
             // mark swap-ins
             if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
                 state->is_swap_in[idx] = 1;
@@ -391,8 +417,9 @@ bool LlamaBatch<T>::Initialize()
         if (s.status == Sequence::kActive) {
             ++back_->active_size;
         }
-        CopyState(coords[i], {back_, back_->size++});
+        cpys.emplace_back(coords[i].first, back_, coords[i].second, back_->size++);
     }
+    CopyState(cpys);

     // Swap the buffers
     std::swap(state_, back_);
@@ -421,6 +448,9 @@ bool LlamaBatch<T>::Initialize()
         // cumulative num of blocks
         h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
+        FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(),
+                           std::to_string(h_cu_block_counts_[i + 1]));
+
         k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
             return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
         });
@@ -444,41 +474,60 @@ bool LlamaBatch<T>::Initialize()
 }

 template<typename T>
-void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst)
+void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
 {
-    const auto& [src, i] = _src;
-    const auto& [dst, j] = _dst;
-    FT_CHECK((bool)src->requests[i]);
-    FT_CHECK(!(bool)dst->requests[j]);
-
-    dst->h_context_length[j] = src->h_context_length[i];
-    dst->h_finished[j]       = src->h_finished[i];
-    dst->h_rope_theta[j]     = src->h_rope_theta[i];
-    dst->seq_len_limit[j]    = src->seq_len_limit[i];
-    dst->sequences[j]        = src->sequences[i];
-    dst->is_swap_in[j]       = src->is_swap_in[i];
-    dst->requests[j]         = std::move(src->requests[i]);
-
-    Copy(src->output_ids + i * session_len_, src->h_context_length[i], dst->output_ids + j * session_len_);
-
-    Copy((curandState_t*)src->top_k_curand_state + i, 1, (curandState_t*)dst->top_k_curand_state + j);
-    Copy((curandState_t*)src->top_p_curand_state + i, 1, (curandState_t*)dst->top_p_curand_state + j);
-}
-
-template<typename T>
-void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
-{
-    Copy(model_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
-    Copy(model_->GetTopPState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
-}
-
-template<typename T>
-void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
-{
-    dbg(idx);
-    Copy((curandState_t*)state.top_k_curand_state + idx, 1, model_->GetTopKState(idx));
-    Copy((curandState_t*)state.top_p_curand_state + idx, 1, model_->GetTopPState(idx));
+    std::vector<int> idxs(desc.size());
+    std::iota(idxs.begin(), idxs.end(), 0);
+    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return desc[i] < desc[j]; });
+
+    auto get_signature = [&](int i) -> std::pair<BatchState*, BatchState*> {
+        return std::make_pair(std::get<0>(desc[idxs[i]]), std::get<1>(desc[idxs[i]]));
+    };
+
+    std::vector<int> offsets;
+    auto             current = get_signature(0);
+    offsets.push_back(0);
+    for (int i = 0; i < idxs.size(); ++i) {
+        if (auto signature = get_signature(i); signature != current) {
+            current = signature;
+            offsets.push_back(i);
+        }
+    }
+    offsets.push_back(idxs.size());
+
+    for (int bi = 1; bi < offsets.size(); ++bi) {
+        int beg = offsets[bi - 1];
+        int end = offsets[bi];
+        if (beg == end) {
+            continue;
+        }
+        auto [s, d] = get_signature(beg);
+        std::vector<int> s_idx;
+        std::vector<int> d_idx;
+        for (int i = beg; i < end; ++i) {
+            s_idx.push_back(std::get<2>(desc[idxs[i]]));
+            d_idx.push_back(std::get<3>(desc[idxs[i]]));
+        }
+        IndexedCopy(s_idx,
+                    d_idx,
+                    std::tuple{s->output_ids, d->output_ids, session_len_},
+                    std::tuple{s->curand_state, d->curand_state, 1});
+    }
+
+    for (const auto& [s, d, si, di] : desc) {
+        d->h_context_length[di] = s->h_context_length[si];
+        d->h_finished[di]       = s->h_finished[si];
+        d->h_rope_theta[di]     = s->h_rope_theta[si];
+        d->seq_len_limit[di]    = s->seq_len_limit[si];
+        d->sequences[di]        = s->sequences[si];
+        d->is_swap_in[di]       = s->is_swap_in[si];
+        d->requests[di]         = s->requests[si];
+    }
 }
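Note: `CopyState` now batches slot copies: the copy descriptors are sorted by their (source state, destination state) pair, and each run of equal pairs is flushed with a single `IndexedCopy` call instead of one copy per slot. A minimal standalone sketch of the run-boundary computation used above (names are illustrative):

    #include <utility>
    #include <vector>

    // Given keys already sorted, return the start offset of each run plus a final sentinel,
    // so [offsets[k], offsets[k + 1]) indexes one batched copy.
    inline std::vector<int> run_offsets(const std::vector<std::pair<void*, void*>>& keys)
    {
        std::vector<int> offsets{0};
        for (int i = 1; i < (int)keys.size(); ++i) {
            if (keys[i] != keys[i - 1]) {
                offsets.push_back(i);
            }
        }
        offsets.push_back((int)keys.size());
        return offsets;
    }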
 template<typename T>
@@ -527,7 +576,6 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
     token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);

-    end_ids_buf_   = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false);
     finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
     seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
@@ -543,30 +591,45 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
 template<typename T>
 void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
 {
-    stop_words_buf_ =
-        (int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
-    bad_words_buf_ =
-        (int*)allocator_->reMalloc(bad_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    d_stop_words_ = (int*)allocator_->reMalloc(d_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    d_bad_words_  = (int*)allocator_->reMalloc(d_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
+    h_stop_words_ =
+        (int*)allocator_->reMalloc(h_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
+    h_bad_words_ =
+        (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);

     h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
     h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
     h_temperature_   = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
     h_repetition_penalty_ =
         (float*)allocator_->reMalloc(h_repetition_penalty_, sizeof(float) * max_batch_size, true, true);
-    h_random_seed_ = (uint64_t*)allocator_->reMalloc(h_random_seed_, sizeof(uint64_t) * max_batch_size, true, true);

-    sampling_params_ = {{"stop_words_list", stop_words_buf_},
-                        {"bad_words_list", bad_words_buf_},
-                        {"runtime_top_k", h_runtime_top_k_},
-                        {"runtime_top_p", h_runtime_top_p_},
-                        {"temperature", h_temperature_},
-                        {"repetition_penalty", h_repetition_penalty_},
-                        {"random_seed", h_random_seed_}};
+    h_random_seed_ = (unsigned long long*)allocator_->reMalloc(
+        h_random_seed_, sizeof(unsigned long long) * max_batch_size, true, true);
+    d_random_seed_ = (unsigned long long*)allocator_->reMalloc(
+        d_random_seed_, sizeof(unsigned long long) * max_batch_size, true, false);
+
+    h_curand_state_ =
+        (curandState_t*)allocator_->reMalloc(h_curand_state_, sizeof(curandState_t) * max_batch_size, true, true);
+    d_curand_state_ =
+        (curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);
+
+    d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
+    h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
+
+    sampling_params_ = {
+        {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
+        {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
+        {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
+        {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
+        {"temperature", (std::byte*)h_temperature_, nullptr},
+        {"repetition_penalty", (std::byte*)h_repetition_penalty_, nullptr},
+    };

     for (auto& s : states_) {
         s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
-        s.top_k_curand_state = allocator_->reMalloc(s.top_k_curand_state, sizeof(curandState_t) * max_batch_size, true);
-        s.top_p_curand_state = allocator_->reMalloc(s.top_p_curand_state, sizeof(curandState_t) * max_batch_size, true);
+        s.curand_state =
+            (curandState_t*)allocator_->reMalloc(s.curand_state, sizeof(curandState_t) * max_batch_size, true);
     }

     const size_t max_block_count = sequence_manager_->max_block_count();
@@ -604,6 +667,9 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
             (int*)allocator_->reMalloc(h_request_output_ids_lens_, sizeof(int) * max_batch_size, true, true);
         h_request_seqlen_ptrs_ =
             (int**)allocator_->reMalloc(h_request_seqlen_ptrs_, sizeof(int*) * max_batch_size, true, true);
+
+        h_output_ids_ =
+            (int*)allocator_->reMalloc(h_output_ids_, sizeof(int) * max_batch_size * session_len_, false, true);
     }

     is_allocate_persistant_buffer_ = true;
@@ -648,7 +714,9 @@ void LlamaBatch<T>::FreeBuffer()
     allocator_->free((void**)&token_ids_buf_);

-    allocator_->free((void**)&end_ids_buf_);
+    allocator_->free((void**)&d_end_ids_buf_);
+    allocator_->free((void**)&h_end_ids_buf_, true);
+
     allocator_->free((void**)&finished_buf_);
     allocator_->free((void**)&seq_limit_len_);
@@ -662,11 +730,22 @@ void LlamaBatch<T>::FreeBuffer()
     }

     if (is_allocate_persistant_buffer_) {
+        allocator_->free((void**)&d_stop_words_);
+        allocator_->free((void**)&h_stop_words_, true);
+        allocator_->free((void**)&d_bad_words_);
+        allocator_->free((void**)&h_bad_words_, true);
+        allocator_->free((void**)&d_random_seed_);
+        allocator_->free((void**)&h_random_seed_, true);
+        allocator_->free((void**)&d_curand_state_);
+        allocator_->free((void**)&h_curand_state_, true);
+
         for (auto& s : states_) {
             allocator_->free((void**)&s.h_context_length, true);
             allocator_->free((void**)&s.h_finished, true);
             allocator_->free((void**)&s.h_rope_theta, true);
             allocator_->free((void**)&s.output_ids);
+            allocator_->free((void**)&s.curand_state);
         }
         allocator_->free((void**)&h_tmp_k_ptrs_, true);
         allocator_->free((void**)&h_tmp_v_ptrs_, true);
@@ -681,6 +760,8 @@ void LlamaBatch<T>::FreeBuffer()
         allocator_->free((void**)&h_request_output_ids_lens_, true);
         allocator_->free((void**)&h_request_seqlen_ptrs_, true);

+        allocator_->free((void**)&h_output_ids_, true);
+
         is_allocate_persistant_buffer_ = false;
     }
 }
...@@ -723,14 +804,15 @@ LlamaBatch<T>::LlamaBatch(int max_batch_size, ...@@ -723,14 +804,15 @@ LlamaBatch<T>::LlamaBatch(int max_batch_size,
template<typename T> template<typename T>
void LlamaBatch<T>::InitializeSampling() void LlamaBatch<T>::InitializeSampling()
{ {
NvtxScope _("InitSampling");
const int batch_size = state_->active_size; const int batch_size = state_->active_size;
TensorMap inputs; TensorMap inputs;
for (const auto& param : sampling_params_) { for (const auto& [name, h_ptr, d_ptr] : sampling_params_) {
// find an exemplar that matches the param name // find an exemplar that matches the param name
const Tensor* ptr{}; const Tensor* ptr{};
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(param.first)) { if (state_->requests[i]->inputs[rank_].isExist(name)) {
ptr = &state_->requests[i]->inputs[rank_].at(param.first); ptr = &state_->requests[i]->inputs[rank_].at(name);
break; break;
} }
} }
...@@ -741,42 +823,50 @@ void LlamaBatch<T>::InitializeSampling() ...@@ -741,42 +823,50 @@ void LlamaBatch<T>::InitializeSampling()
FT_CHECK(shape[0] == 1); FT_CHECK(shape[0] == 1);
shape[0] = batch_size; shape[0] = batch_size;
const int size_in_bytes = ref.sizeBytes(); const int size_in_bytes = ref.sizeBytes();
Clear((std::byte*)param.second, size_in_bytes * batch_size); memset(h_ptr, 0, size_in_bytes * batch_size);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(param.first)) { if (state_->requests[i]->inputs[rank_].isExist(name)) {
auto& src = state_->requests[i]->inputs[rank_].at(param.first); Tensor& src = state_->requests[i]->inputs[rank_].at(name);
FT_CHECK(ref.shape == src.shape); FT_CHECK(ref.shape == src.shape);
Copy(src.getPtr<std::byte>(), size_in_bytes, (std::byte*)param.second + size_in_bytes * i); std::copy_n(src.getPtr<std::byte>(), size_in_bytes, h_ptr + size_in_bytes * i);
} }
} }
inputs.insert({param.first, {ref.where, ref.type, shape, param.second}}); if (d_ptr) {
Copy(h_ptr, batch_size * size_in_bytes, d_ptr);
}
inputs.insert({name, {d_ptr ? MEMORY_GPU : MEMORY_CPU, ref.type, shape, d_ptr ? d_ptr : h_ptr}});
if (debug_ && rank_ == 0) { if (debug_ && rank_ == 0) {
TM_LOG_INFO("[initializeSampling] %s", format({param.first, inputs.at(param.first)}).c_str()); TM_LOG_INFO("[initializeSampling] %s", format({name, inputs.at(name)}).c_str());
} }
} }
} }
// init for eos
std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)batch_size}, d_end_ids_buf_}});
inputs_ = std::move(inputs); inputs_ = std::move(inputs);
model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_); model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
// recover random states if not a new request
for (int i = 0; i < batch_size; ++i) {
if (!state_->requests[i]->start_flag && state_->is_swap_in[i]) {
LoadRandomState(*state_, i);
}
}
handleOptArg(&inputs_, "end_id", end_ids_buf_, model_->end_id_, batch_size);
cudaStreamSynchronize(0);
} }
 template<typename T>
 auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
 {
+    NvtxScope _("InitGen");
     const int batch_size      = state_->active_size;
     const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);

+    Copy(state_->h_context_length, batch_size, context_length_buf_);  // also referenced in `SetOutputTensors`
+    Copy(context_length_buf_, batch_size, sequence_lengths_);
+    // `sequence_lengths_` will be increased by dynamic decode
+    // note that in decoder and in output "sequence length" has different semantic
+    // - in decoder it means length of sequence that has kv cache already computed
+    // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
+    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
+    sync_check_cuda_error();
+
     Clear(token_ids_buf_, batch_size * session_len_);
     invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
     sync_check_cuda_error();
@@ -787,22 +877,7 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
     // ABCDEFGHi -> ABCDEFGHi i
     // ABCDEFGh     ABCDEFGh h
     // ABCd         ABCd     d
-    for (int i = 0; i < batch_size; ++i) {
-        auto token_ids = token_ids_buf_ + i;
-        auto p_src     = state_->h_context_length[i] - 1;
-        auto p_dst     = max_context_len - 1;
-        if (p_src != p_dst) {  // dst and src of `cudaMemcpyAsync` must not overlap
-            Copy(token_ids + p_src * batch_size, 1, token_ids + p_dst * batch_size);
-        }
-    }
-
-    Copy(state_->h_context_length, batch_size, context_length_buf_);  // also referenced in `SetOutputTensors`
-    Copy(context_length_buf_, batch_size, sequence_lengths_);
-    // `sequence_lengths_` will be increased by dynamic decode
-    // note that in decoder and in output "sequence length" has different semantic
-    // - in decoder it means length of sequence that has kv cache already computed
-    // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
-    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
+    invokePadLastTokenIds(token_ids_buf_, context_length_buf_, max_context_len, batch_size, stream_);
     sync_check_cuda_error();

     // used for dispatching split-k decoding kernels
@@ -846,21 +921,19 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
         TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
         TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len);

-        TM_LOG_INFO("[initGen] slot  sequence_id  context_len  seq_limit_len  finished");
-        for (int i = 0; i < batch_size; ++i) {
-            TM_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d",
-                        i,
-                        (long)state_->sequences[i]->id,
-                        state_->h_context_length[i],
-                        (int)h_seq_limit_len_[i],
-                        (int)state_->h_finished[i]);
+        if (debug_) {
+            TM_LOG_INFO("[initGen] slot  sequence_id  context_len  seq_limit_len  finished");
+            for (int i = 0; i < batch_size; ++i) {
+                TM_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d",
+                            i,
+                            (long)state_->sequences[i]->id,
+                            state_->h_context_length[i],
+                            (int)h_seq_limit_len_[i],
+                            (int)state_->h_finished[i]);
+            }
         }
     }

-    // for (int i = 0; i < batch_size; ++i) {
-    //     gSequenceIds(i) = state_->requests[i]->id;
-    // }
-
     return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
 }
@@ -908,6 +981,9 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
                            decoder_output_buf_,
                            batch_size);

+    /// sync for better NVTX visualization, THIS IS NOT NEEDED
+    // check_cuda_error(cudaStreamSynchronize(stream_));
+
     // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
     // not supported yet.
     bool should_stop{};
@@ -915,12 +991,13 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
                                           finished_buf_,
                                           sequence_lengths_,
                                           &should_stop,
+                                          state_->curand_state,
                                           &inputs_,
                                           &outputs_,
                                           logits_buf_,
                                           seq_limit_len_,
                                           context_length_buf_,
-                                          end_ids_buf_,
+                                          d_end_ids_buf_,
                                           g.step,
                                           0,
                                           g.max_init_ctx_len,
@@ -960,6 +1037,7 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
 template<typename T>
 void LlamaBatch<T>::ContextDecode()
 {
+    NvtxScope  _("prefill");
     const auto batch_size = state_->active_size;

     int base = -1;
@@ -987,8 +1065,8 @@ void LlamaBatch<T>::ContextDecode()
     Copy(state_->h_rope_theta, batch_size, rope_theta_);
     Copy(h_input_length_buf_, batch_size, input_length_buf_);

-    check_cuda_error(cudaStreamSynchronize(stream_));
-    const auto tick = std::chrono::high_resolution_clock::now();
+    // check_cuda_error(cudaStreamSynchronize(stream_));
+    // const auto tick = std::chrono::high_resolution_clock::now();

     if (rank_ == 0) {
         TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
@@ -1039,7 +1117,7 @@ void LlamaBatch<T>::ContextDecode()
         auto input_ids = context_decoder_ids_buf_;
         TM_LOG_INFO("first = %d, last = %d", first, last);
         for (int i = first; i < last; ++i) {
-            TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
+            // TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
             input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
             dbg(i, h_input_length_buf_[i]);
             h_tmp_k_ptrs_[i] = k_ptr;
@@ -1069,15 +1147,6 @@ void LlamaBatch<T>::ContextDecode()
         dbg(first, last);
         dbg(k_block_ptrs_, v_block_ptrs_);

-        if (1) {
-            std::vector<int> input_len(sub_batch_size);
-            std::vector<int> context_len(sub_batch_size);
-            Copy(input_length_buf_ + first, sub_batch_size, input_len.data());
-            Copy(context_length_buf_ + first, sub_batch_size, context_len.data());
-            cudaStreamSynchronize(stream_);
-            dbg(input_len, context_len);
-        }
-
         model_->contextDecode(nullptr,
                               k_block_ptrs_,
                               v_block_ptrs_,
@@ -1112,11 +1181,11 @@ void LlamaBatch<T>::ContextDecode()
         }
     }

-    check_cuda_error(cudaStreamSynchronize(stream_));
-    const auto tock = std::chrono::high_resolution_clock::now();
-    if (rank_ == 0) {
-        TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
-    }
+    // check_cuda_error(cudaStreamSynchronize(stream_));
+    // const auto tock = std::chrono::high_resolution_clock::now();
+    // if (rank_ == 0) {
+    //     TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
+    // }
 }
 template<typename T>
@@ -1167,65 +1236,43 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
 }

 template<typename T>
-auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
+auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
 {
     NvtxScope scope("Finish");
     const int batch_size = state_->active_size;

-    // secure info needed by `Initialize()`
-    Copy(finished_buf_, batch_size, state_->h_finished);
-
-    // invariant: context_length = sequence_length + 1
-    invokePlusScalar(sequence_lengths_, 1, batch_size, stream_);
-    Copy(sequence_lengths_, batch_size, state_->h_context_length);
-    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
-
-    if constexpr (0) {
-        std::unique_lock<std::mutex> lock;
-        if (rank_ == 0) {
-            NvtxScope _("acquire_outputs");
-            // wait for previous output operations
-            lock = std::unique_lock{output_mutex_};
-            output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
-        }
-        SetOutputTensors(g);
-        check_cuda_error(cudaStreamSynchronize(stream_));
-        if (rank_ == 0) {
-            NvtxScope _("signal_output_thread");
-            // enqueue new output requests
-            for (int i = 0; i < batch_size; ++i) {
-                FT_CHECK(state_->requests[i] != nullptr);
-                if (state_->requests[i]->stream_cb) {
-                    output_reqs_.push_back(state_->requests[i]);
-                }
-            }
-            lock.unlock();
-            // notify output thread when we do have stream cbs to call
-            if (!output_reqs_.empty()) {
-                output_cv_.notify_one();
-            }
-        }
-    }
-    else {
-        SetOutputTensors(g);
-        check_cuda_error(cudaStreamSynchronize(stream_));
-        {
-            NvtxScope _("output_cb");
-            if (rank_ == 0 && model_->ffi_lock_) {
-                model_->ffi_lock_(1);
-            }
-            for (int i = 0; i < batch_size; ++i) {
-                FT_CHECK(state_->requests[i] != nullptr);
-                if (state_->requests[i]->stream_cb && rank_ == 0) {
-                    state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
-                }
-            }
-            if (rank_ == 0 && model_->ffi_lock_) {
-                model_->ffi_lock_(0);
-            }
-        }
-    }
+    // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
+    invokeGatherOutput(state_->output_ids,
+                       token_ids_buf_,
+                       context_length_buf_,
+                       g.max_init_ctx_len,
+                       g.step,
+                       session_len_,
+                       batch_size,
+                       stream_);
+    sync_check_cuda_error();
+
+    Copy(state_->output_ids, batch_size * session_len_, h_output_ids_);
+    Copy(finished_buf_, batch_size, state_->h_finished);
+    Copy(sequence_lengths_, batch_size, state_->h_context_length);
+
+    check_cuda_error(cudaStreamSynchronize(stream_));
+
+    // invariant: context_length = sequence_length + 1
+    for (int i = 0; i < batch_size; ++i) {
+        ++state_->h_context_length[i];
+    }
+
+    {  // set output tokens ids and sequence length
+        int* output_ptr = h_output_ids_;
+        for (int i = 0; i < batch_size; ++i) {
+            if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
+                const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
+                // TODO: sync history output tokens at when receiving the request and copy only the last token here
+                std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
+                *h_request_seqlen_ptrs_[i] = count;
+            }
+            output_ptr += session_len_;
+        }
+    }

@@ -1247,66 +1294,37 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
     std::vector<Signal> signals;
     {
-        NvtxScope _("prepare_completion_signal");
+        NvtxScope _("stream_and_completion_signal");
         for (int i = 0; i < batch_size; ++i) {
-            if (state_->requests[i] && state_->h_finished[i]) {
-                CompleteRequest(i, false, false);
-                signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
+            if (state_->requests[i]) {
+                if (state_->h_finished[i]) {
+                    // Interrupt finished sequences and move the request handle into the signal closure
+                    signals.push_back(Interrupt(i));
+                    ++finished_count;
+                }
+                else if (state_->requests[i]->stream_cb) {
+                    // Create signals by copying the request handles for non-finished streaming requests
+                    signals.push_back([this, r = state_->requests[i]] {
+                        if (rank_ == 0) {
+                            r->stream_cb(&r->outputs[rank_].get());
+                        }
+                    });
+                }
             }
         }
+        if (finished_count) {
+            // synchronize for interrupted sequences
+            check_cuda_error(cudaStreamSynchronize(stream_));
+        }
     }
     return signals;
 }
 template<typename T>
-void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
-{
-    NvtxScope scope("SetOutputTensors");
-    // dbg(g.max_init_ctx_len);
-    const auto batch_size = state_->active_size;
-    // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
-    invokeGatherOutput(state_->output_ids,
-                       token_ids_buf_,
-                       context_length_buf_,
-                       g.max_init_ctx_len,
-                       g.step,
-                       session_len_,
-                       batch_size,
-                       stream_);
-    sync_check_cuda_error();
-    if constexpr (1) {
-        invokeUpdateOutput(request_output_ids_ptrs_,
-                           request_seqlen_ptrs_,
-                           state_->output_ids,
-                           sequence_lengths_,
-                           request_output_ids_lens_,
-                           session_len_,
-                           g.step > g.max_init_ctx_len,
-                           batch_size,
-                           stream_);
-        sync_check_cuda_error();
-    }
-    else {
-        // for (int i = 0; i < batch_size; ++i) {
-        //     if (state_->requests[i]) {
-        //         auto& output_ids      = state_->requests[i]->outputs[rank_].at("output_ids");
-        //         auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
-        //         Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
-        //         Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
-        //         if (g.step > g.max_init_ctx_len) {  // +1 for newly generated token
-        //             invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
-        //         }
-        //     }
-        // }
-    }
-}
-
-template<typename T>
-void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_force_end)
+auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Signal
 {
     if (rank_ == 0) {
-        TM_LOG_INFO("[CompleteRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
+        TM_LOG_INFO("[Interrupt] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
     }
 
     if (debug_ && rank_ == 0) {
@@ -1317,45 +1335,46 @@ void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_for
         for (const auto& t : tokens) {
             ss << " " << t;
         }
-        TM_LOG_INFO("[CompleteRequest] slot %d, tokens [%s]", index, ss.str().c_str());
+        TM_LOG_INFO("[Interrupt] slot %d, tokens [%s]", index, ss.str().c_str());
     }
 
-    if (state_->requests[index]->end_flag || is_force_end) {
-        sequence_manager_->Erase(state_->requests[index]->id);
+    if (state_->requests[index]->end_flag || force_end) {
+        // Sequence is ending this round or a stop request is issued to end it
+        FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
     }
     else {
-        // account for the last generated token if not a stop request (which doesn't generate)
-        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(is_stop_request);
-        auto& seq = *state_->sequences[index];
-        // update token IDs
+        // Account for the last generated token if not a stop request (which doesn't generate)
+        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
+        auto& seq = *state_->sequences[index];
+        // Update token IDs
         seq.tokens.resize(output_len);
         const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
-        Copy(output_ids_data, output_len, seq.tokens.data());
-        // update random states
-        seq.random_state.resize(sizeof(curandState_t) * 2);
-        // save random state in host memory
-        if (auto ptr = (curandState_t*)seq.random_state.data()) {
-            ptr = Copy(model_->GetTopKState(index), 1, ptr);
-            ptr = Copy(model_->GetTopPState(index), 1, ptr);
-        }
-        check_cuda_error(cudaStreamSynchronize(stream_));
+        std::copy_n(output_ids_data, output_len, seq.tokens.data());
+        // Save random state in host memory
+        seq.random_state.resize(sizeof(curandState_t));
+        // This async copy must be synchronized by the caller
+        Copy(state_->curand_state + index, 1, (curandState_t*)seq.random_state.data());
+        // Set unlock flag for corresponding blocks, will be unlocked in the next `Materialize()`
         sequence_manager_->UpdateAndSetUnlock(seq);
     }
 
     state_->sequences[index] = nullptr;
+
+    // move the request handle into the signal
+    return [this, r = std::move(state_->requests[index])] {
+        if (rank_ == 0) {
+            r->signal.set_value(0);
+        }
+    };
 }
 
 template<typename T>
 void LlamaBatch<T>::InternalThreadEntry(int device_id)
 {
-    TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
+    // TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
     check_cuda_error(cudaSetDevice(device_id));
 
     auto& shared_state = model_->shared_state_;
@@ -1364,20 +1383,26 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
     auto& infer_requests = shared_state->infer_requests;
     auto& stop_requests  = shared_state->stop_requests;
 
+    // sequences that are removed but still counted in state's size
     int finished_count = 0;
 
     GenerationState g{};
 
+    constexpr int request_interval = 1;
+    long          request_counter  = 0;
+
     while (1) {
         if (rank_ == 0) {
             const int  free_slot_count = max_batch_size_ - state_->size + finished_count;
             const bool is_empty        = (free_slot_count == max_batch_size_);
-            // will block if batch is empty
-            request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
-            if (!shared_state->abort) {
-                RejectInvalidRequests(stop_requests, infer_requests);
-            }
+            stop_requests.clear();
+            infer_requests.clear();
+            if (is_empty || request_counter % request_interval == 0) {
+                // Block if batch is empty
+                request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
+                if (!shared_state->abort) {
+                    RejectInvalidRequests(stop_requests, infer_requests);
+                }
+            }
         }
         }
@@ -1388,20 +1413,19 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
         if (shared_state->abort) {
             TM_LOG_INFO("[InternalThreadEntry] stop requested.");
-            // if (state_->size && rank_ == 0) {
-            //     TM_LOG_WARNING("Active request(s) present (%d) while exiting.", state_->size);
-            // }
             return;
         }
 
         auto signals = ProcessStopRequests(stop_requests);
-        BarrierSignalRequests(*shared_state->barrier, signals);
 
+        // Shared `priority` field will be assigned by rank-0
         ProcessInferRequests(infer_requests);
 
-        // wait while shared stop/infer_requests is being used
+        // Wait while shared `requests` is being used
        shared_state->barrier->wait();
 
+        SendSignals(std::move(signals));
+
         auto modified = Initialize();
         // finished sequences is handled by `Initialize()`
         finished_count = 0;
@@ -1418,31 +1442,42 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
                     break;
                 }
             }
 
-            auto signals   = Finish(g);
-            finished_count = signals.size();
-            BarrierSignalRequests(*shared_state->barrier, signals);
+            if (auto signals = Finish(g, finished_count); !signals.empty()) {
+                if (finished_count) {
+                    // Finished requests and corresponding output tensors will be released when notified
+                    // wait for all ranks to ensure no rank (except for output thread) will access related
+                    // resources
+                    shared_state->barrier->wait();
+                }
+                SendSignals(std::move(signals));
+            }
         }
 
+        ++request_counter;
     }
 
     FT_CHECK(0);
 }
 template<typename T>
-void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals)
+void LlamaBatch<T>::SendSignals(std::vector<Signal> signals)
 {
-    if (!signals.empty()) {
-        barrier.wait();
-        if (rank_ == 0) {
-            std::for_each(signals.cbegin(), signals.cend(), [](auto& s) { s(); });
-        }
-        barrier.wait();
+    if (rank_ != 0 || signals.empty()) {
+        return;
+    }
+    {
+        std::lock_guard lock{output_mutex_};
+        output_signals_.insert(output_signals_.end(),  //
+                               std::move_iterator{signals.begin()},
+                               std::move_iterator{signals.end()});
     }
+    output_cv_.notify_one();
 }
 
 template<typename T>
 void LlamaBatch<T>::Start()
 {
-    TM_LOG_ERROR("LlamaBatch<T>::Start()");
+    TM_LOG_INFO("LlamaBatch<T>::Start()");
     int device_id = -1;
     check_cuda_error(cudaGetDevice(&device_id));
     internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
@@ -1455,36 +1490,27 @@ template<typename T>
 void LlamaBatch<T>::OutputThreadEntry()
 {
     while (true) {
+        std::vector<Signal> signals;
         {
-            // wait for requests with stream cbs
+            // Wait for signals to come
             std::unique_lock lock(output_mutex_);
-            output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
-            // NvtxScope _("output_callback");
-            // stop requested
+            output_cv_.wait(lock, [&] { return !output_signals_.empty() || output_stop_token_; });
             if (output_stop_token_) {
                 TM_LOG_INFO("[OutputThreadEntry] stop requested.");
                 return;
             }
-            if (rank_ == 0 && model_->ffi_lock_) {
-                TM_LOG_INFO("acquire GIL");
-                model_->ffi_lock_(1);
-                TM_LOG_INFO("acquire GIL success");
-            }
-            // invoke stream cbs
-            for (const auto& r : output_reqs_) {
-                r->stream_cb(&r->outputs[rank_].get());
-            }
-            if (rank_ == 0 && model_->ffi_lock_) {
-                TM_LOG_INFO("release GIL");
-                model_->ffi_lock_(0);
-                TM_LOG_INFO("release GIL success");
-            }
-            output_reqs_.clear();
+            signals = std::move(output_signals_);
+        }
+        if (rank_ == 0 && model_->ffi_lock_) {
+            model_->ffi_lock_(1);
+        }
+        // invoke stream cbs & signals
+        for (const auto& s : signals) {
+            s();
+        }
+        if (rank_ == 0 && model_->ffi_lock_) {
+            model_->ffi_lock_(0);
         }
-        FT_CHECK(output_reqs_.empty());
-        // notify infer thread 0
-        output_cv_.notify_one();
     }
 }
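SendSignals and OutputThreadEntry above form a small producer/consumer hand-off: the generation loop appends completion and streaming closures under output_mutex_ and wakes the output thread, which drains them outside the lock and only then takes the Python GIL. A minimal standalone sketch of the same pattern, using plain std::function for the signals (illustration only, not TurboMind code):

    #include <chrono>
    #include <condition_variable>
    #include <cstdio>
    #include <functional>
    #include <mutex>
    #include <thread>
    #include <vector>

    using Signal = std::function<void()>;

    std::mutex              mtx;
    std::condition_variable cv;
    std::vector<Signal>     pending;
    bool                    stop = false;

    // Producer side (analogous to SendSignals): append under the lock, notify outside it.
    void send(std::vector<Signal> signals)
    {
        {
            std::lock_guard lock{mtx};
            pending.insert(pending.end(),
                           std::move_iterator{signals.begin()},
                           std::move_iterator{signals.end()});
        }
        cv.notify_one();
    }

    // Consumer side (analogous to OutputThreadEntry): swap the queue out, run callbacks unlocked.
    void output_loop()
    {
        while (true) {
            std::vector<Signal> signals;
            {
                std::unique_lock lock{mtx};
                cv.wait(lock, [] { return !pending.empty() || stop; });
                if (stop) {
                    return;
                }
                signals.swap(pending);
            }
            for (const auto& s : signals) {
                s();  // e.g. invoke a stream callback or fulfil a completion promise
            }
        }
    }

    int main()
    {
        std::thread consumer{output_loop};
        send({[] { std::puts("token batch ready"); }});
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        {
            std::lock_guard lock{mtx};
            stop = true;
        }
        cv.notify_one();
        consumer.join();
    }

Draining a swapped-out vector keeps the mutex held only for the queue exchange, so callbacks that block (GIL, user code) never stall the generation loop.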
......
@@ -3,14 +3,18 @@
 #pragma once
 
 // #include "src/turbomind/models/llama/LlamaCacheManager.h"
+#include "src/turbomind/layers/sampling_layers/BaseSamplingLayer.h"
 #include "src/turbomind/models/llama/Barrier.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/Request.h"
 #include "src/turbomind/models/llama/SequenceManager.h"
+#include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/utils/allocator.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include <condition_variable>
 #include <mutex>
+#include <type_traits>
 
 namespace turbomind {
 
@@ -18,9 +22,8 @@ struct BatchState {
     int*  h_context_length;
     bool* h_finished;
 
-    void* top_k_curand_state;
-    void* top_p_curand_state;
-    int*  output_ids;  // output ids in [B, S]
+    curandState_t* curand_state;
+    int*           output_ids;  // output ids in [B, S]
 
     float* h_rope_theta;
@@ -66,16 +69,15 @@ public:
         int max_seq_len;
     };
 
     void InitializeSampling();
 
     GenerationState InitializeGeneration();
 
     [[nodiscard]] bool Generate(GenerationState& g);
 
-    [[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
-    void CompleteRequest(int index, bool is_stop_request, bool is_force_end);
-    void SetOutputTensors(const GenerationState& g);
+    [[nodiscard]] auto Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>;
+
+    [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
 
     void
     OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
@@ -88,7 +90,7 @@ public:
     ~LlamaBatch()
     {
-        TM_LOG_ERROR("~LlamaBatch()");
+        TM_LOG_INFO("~LlamaBatch()");
 
         model_->shared_state_->request_queue.close();
 
         internal_thread_.join();
@@ -112,15 +114,9 @@ private:
     void OutputThreadEntry();
 
-    void UpdateSequenceStates(BatchState& state, int index);
-
-    void CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst);
-
-    void SaveRandomState(BatchState& state, int idx);
-
-    void LoadRandomState(BatchState& state, int idx);
-
-    void BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals);
+    void CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc);
+
+    void SendSignals(std::vector<Signal> signals);
 
     // analogs to `std::copy_n`
     template<typename U>
@@ -137,6 +133,47 @@ private:
         return data += count;
     }
template<class... Ts>
void IndexedCopyImpl(const int* src_idx, const int* dst_idx, int count, const std::tuple<Ts*, Ts*, int>&... cpys)
{
if (!count) {
return;
}
constexpr int N = sizeof...(Ts);
static_assert((!std::is_same_v<Ts, void> && ...));
std::array<void*, N> src_ptr{std::get<0>(cpys)...};
std::array<void*, N> dst_ptr{std::get<1>(cpys)...};
std::array<int, N> elem_sz{int(sizeof(Ts) * std::get<2>(cpys))...};
invokeIndexedCopy(src_ptr.data(), //
dst_ptr.data(),
elem_sz.data(),
src_idx,
dst_idx,
count,
N,
stream_);
sync_check_cuda_error();
}
template<class... Ts>
void IndexedCopy(const std::vector<int>& src_idx,
const std::vector<int>& dst_idx,
const std::tuple<Ts*, Ts*, int>&... cpys)
{
// has the same size, or one is empty
FT_CHECK(src_idx.size() == dst_idx.size() || (src_idx.empty() ^ dst_idx.empty()));
IndexedCopyImpl(src_idx.empty() ? nullptr : src_idx.data(),
dst_idx.empty() ? nullptr : dst_idx.data(),
std::max(src_idx.size(), dst_idx.size()),
cpys...);
}
template<class... Ts>
void IndexedCopy(int count, const std::tuple<Ts*, Ts*, int>&... cpys)
{
IndexedCopyImpl(nullptr, nullptr, count, cpys...);
}
 private:
     const int max_batch_size_;
     const int max_context_token_num_;
 
@@ -186,9 +223,10 @@ private:
     // used by dynamic decoder
     int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
-    int*      end_ids_buf_{};
     bool*     finished_buf_{};
     uint32_t* seq_limit_len_{};
+    int*      h_end_ids_buf_{};
+    int*      d_end_ids_buf_{};
 
     int** request_output_ids_ptrs_{};
     int*  request_output_ids_lens_{};
 
@@ -205,13 +243,20 @@ private:
     uintptr_t* h_k_block_ptrs_{};
     uintptr_t* h_v_block_ptrs_{};
 
-    int*      stop_words_buf_{};  // [batch_size, 2, kMaxStopWordsLen]
-    int*      bad_words_buf_{};
-    int*      h_runtime_top_k_{};
-    float*    h_runtime_top_p_{};
-    float*    h_temperature_{};
-    float*    h_repetition_penalty_{};
-    uint64_t* h_random_seed_{};
+    int*   h_runtime_top_k_{};
+    float* h_runtime_top_p_{};
+    float* h_temperature_{};
+    float* h_repetition_penalty_{};
+
+    int* h_stop_words_{};  // [batch_size, 2, kMaxStopWordsLen]
+    int* h_bad_words_{};
+    int* d_stop_words_{};  // [batch_size, 2, kMaxStopWordsLen]
+    int* d_bad_words_{};
+
+    unsigned long long* h_random_seed_{};
+    unsigned long long* d_random_seed_{};
+
+    curandState_t* h_curand_state_{};
+    curandState_t* d_curand_state_{};
 
     std::array<BatchState, 3> states_{};
 
@@ -232,7 +277,7 @@ private:
     TensorMap inputs_;
     TensorMap outputs_;
 
-    std::unordered_map<std::string, void*> sampling_params_;
+    std::vector<std::tuple<std::string, std::byte*, std::byte*>> sampling_params_;
 
     cudaStream_t     stream_{};
     cublasMMWrapper* cublas_wrapper_{};
 
@@ -244,8 +289,10 @@ private:
     std::thread             output_thread_;
     std::mutex              output_mutex_;
    std::condition_variable output_cv_;
-    Requests                output_reqs_;
+    std::vector<Signal>     output_signals_;
     bool                    output_stop_token_{false};
+
+    int* h_output_ids_{};
 };
 
 }  // namespace turbomind
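The IndexedCopy helpers added to this header pack several typed (src, dst, elements-per-slot) tuples into untyped pointer/stride arrays and hand them to invokeIndexedCopy, which performs all of the per-slot copies in a single kernel launch. A plain-CPU reference of the semantics of one such call (a sketch for clarity; not part of the commit):

    #include <cstring>
    #include <vector>

    // For each copied field k and each slot pair j, move elem_sz[k] bytes from slot
    // src_idx[j] of the k-th source buffer to slot dst_idx[j] of the k-th destination.
    void indexed_copy_reference(const std::vector<char*>& src_ptr,
                                const std::vector<char*>& dst_ptr,
                                const std::vector<int>&   elem_sz,  // bytes per slot for each field
                                const std::vector<int>&   src_idx,
                                const std::vector<int>&   dst_idx)
    {
        for (size_t j = 0; j < src_idx.size(); ++j) {
            for (size_t k = 0; k < src_ptr.size(); ++k) {
                std::memcpy(dst_ptr[k] + elem_sz[k] * dst_idx[j],
                            src_ptr[k] + elem_sz[k] * src_idx[j],
                            elem_sz[k]);
            }
        }
    }

When one of the index vectors is left empty, the GPU path substitutes the identity mapping 0, 1, 2, …, which is what the std::iota branch inside invokeIndexedCopyImpl provides.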
@@ -45,37 +45,37 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
     const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
 
     // no padding
-    qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);
+    qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, false);
 
     // padding is rebuilt for q/k/v_buf_2_
     // [qH + 2kvH, B, S, D]
     q_buf_2_ = (T*)allocator_->reMalloc(
-        q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
+        q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, false);
     k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
     v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;
 
     if (use_fmha_) {
         FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
         if (flash_attention.get_workspace_size() > 0) {
-            qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), true);
+            qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), false);
         }
     }
     else {
         // kv heads are repeated for unfused attention
         k_cache_buf_ = (T*)allocator_->reMalloc(
-            k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
+            k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, false);
         v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;
 
         qk_buf_ =
-            (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);
+            (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, false);
 
         // qkv_buf_2_ has padding
         qkv_buf_2_ = (T*)allocator_->reMalloc(
-            qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
+            qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, false);
     }
 
     // qkv_buf_3_ padding is removed
-    qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);
+    qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, false);
 
     is_allocate_buffer_ = true;
 }
......
@@ -45,7 +45,7 @@ void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
         reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
 
     workspace_ = (float*)allocator_->reMalloc(
-        workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2));
+        workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2), false);
 
     is_allocate_buffer_ = true;
 }
......
@@ -102,10 +102,6 @@ LlamaV2<T>::LlamaV2(size_t head_num,
     size_t elem_bits = 0;
     if (quant_policy & QuantPolicy::kCacheKVInt8) {
         elem_bits = sizeof(int8_t) * 8;
-        if (use_context_fmha) {
-            TM_LOG_ERROR("use_context_fmha not support int8");
-            assert(0);
-        }
     }
     else {
         elem_bits = sizeof(T) * 8;

@@ -406,6 +402,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
                                bool*          finished,
                                int*           sequence_length,
                                bool*          should_stop,
+                               curandState_t* curand_state,
                                TensorMap*     inputs,
                                TensorMap*     outputs,
                                const float*   logits,

@@ -450,7 +447,8 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
         {"output_ids", {MEMORY_GPU, TYPE_INT32, {token_ids_len, batch_size, 1U}, token_ids}},
         {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
         {"sequence_length", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
-        {"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}}};
+        {"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}},
+        {"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, curand_state}}};
 
     const std::vector<std::string> optional_outputs{"cum_log_probs", "output_log_probs"};
     for (const auto& key : optional_outputs) {

@@ -562,7 +560,7 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
             if (ec) {
                 has_error = true;
             }
-            TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
+            TM_LOG_INFO("[forward] Request complete for %ld, code %d", (long)ids[i], (int)ec);
         }
     }
......
@@ -151,6 +151,7 @@ private:
                        bool*          finished,
                        int*           sequence_length,
                        bool*          should_stop,
+                       curandState_t* curand_state,
                        TensorMap*     inputs,
                        TensorMap*     outputs,
                        const float*   logits,

@@ -163,16 +164,6 @@ private:
                        size_t token_ids_len,
                        size_t batch_size);
 
-    curandState_t* GetTopKState(int index)
-    {
-        return dynamic_decode_layer_->topk_curandstate_buf() + index;
-    }
-
-    curandState_t* GetTopPState(int index)
-    {
-        return dynamic_decode_layer_->topp_curandstate_buf() + index;
-    }
-
 private:
     friend class LlamaBatch<T>;
......
@@ -87,11 +87,8 @@ bool SequenceManager::Erase(uint64_t id)
             }
         }
         sequences_.erase(it);
+        return true;
     }
-    else {
-        throw std::out_of_range(std::to_string(id));
-    }
 
     return false;
 }
......
@@ -58,13 +58,13 @@ public:
     SequenceManager(const SequenceManager&) = delete;
     SequenceManager(SequenceManager&&) noexcept = default;
 
-    const Sequence* Create(uint64_t id);
+    [[nodiscard]] const Sequence* Create(uint64_t id);
 
-    const Sequence* Get(uint64_t id);
+    [[nodiscard]] const Sequence* Get(uint64_t id);
 
-    bool Contains(uint64_t id);
+    [[nodiscard]] bool Contains(uint64_t id);
 
-    bool Erase(uint64_t id);
+    [[nodiscard]] bool Erase(uint64_t id);
 
     void UpdateAndSetUnlock(const Sequence& seq);

@@ -74,10 +74,10 @@ public:
         int swap_out;
     };
 
-    Outcome Materialize(Sequences                    sequences,
-                        std::vector<int>             context_lengths,
-                        const std::vector<uint64_t>& priorities,
-                        int                          step_length);
+    [[nodiscard]] Outcome Materialize(Sequences                    sequences,
+                                      std::vector<int>             context_lengths,
+                                      const std::vector<uint64_t>& priorities,
+                                      int                          step_length);
 
     void* OffsetKey(void* block_ptr)
     {
......
@@ -8,7 +8,11 @@
 #include "src/turbomind/models/llama/llama_kernels.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
+#include <algorithm>
+#include <cstdint>
+#include <cub/block/block_reduce.cuh>
 #include <type_traits>
 
 namespace turbomind {

@@ -606,6 +610,173 @@ void invokeUpdateOutput(int** request_output_ids_ptrs,
                         token_generated);
 }
template<int BLOCK_DIM>
__global__ void compactOutputIds(
int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, bool token_generated)
{
typedef cub::BlockReduce<int, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int batch_idx = blockIdx.x;
int end = (batch_idx + BLOCK_DIM - 1) / BLOCK_DIM * BLOCK_DIM; // align to BLOCK_DIM boundary
int count = 0;
for (int i = threadIdx.x; i < end; i += blockDim.x) {
        int x = i < batch_idx ? sequence_lengths[i] : 0;  // each thread reads one preceding sequence length per iteration
count += BlockReduce(temp_storage).Sum(x);
// https://nvlabs.github.io/cub/classcub_1_1_block_reduce.html
__syncthreads();
}
__shared__ int offset;
if (threadIdx.x == 0) {
offset = count;
}
__syncthreads();
auto dst = cu_output_ids + offset;
const int seq_len = sequence_lengths[batch_idx];
for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
dst[i] = output_ids[batch_idx * session_len + i];
}
}
void invokeCompactOutputIds(int* cu_output_ids,
const int* output_ids,
const int* sequence_lengths,
int max_session_len,
bool token_generated,
int batch_size,
cudaStream_t stream)
{
constexpr int BLOCK_DIM = 128;
compactOutputIds<BLOCK_DIM><<<batch_size, BLOCK_DIM, 0, stream>>>(
cu_output_ids, output_ids, sequence_lengths, max_session_len, token_generated);
}
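For reference, the kernel above packs each sequence's tokens back-to-back: block b first reduces the lengths of all preceding sequences to find its write offset, then copies its own seq_len tokens out of the session_len-strided layout. A plain-CPU sketch of the same result (illustration only, not part of the commit):

    #include <vector>

    std::vector<int> compact_output_ids(const std::vector<int>& output_ids,        // [batch_size * session_len]
                                        const std::vector<int>& sequence_lengths,  // [batch_size]
                                        int                     session_len)
    {
        std::vector<int> cu_output_ids;
        for (size_t b = 0; b < sequence_lengths.size(); ++b) {
            const int len = sequence_lengths[b];
            // the insertion point is the running sum of previous lengths, which is
            // what the cub::BlockReduce loop computes for each block on the GPU
            cu_output_ids.insert(cu_output_ids.end(),
                                 output_ids.begin() + b * session_len,
                                 output_ids.begin() + b * session_len + len);
        }
        return cu_output_ids;
    }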
template<int N, int C>
struct IndexedCopyParam {
Array<void*, N> src_ptr;
Array<void*, N> dst_ptr;
Array<int, N> stride;
Array<int, C> src_idx;
Array<int, C> dst_idx;
int max_stride;
};
template<class T, int N, int C>
__global__ void indexedCopy(IndexedCopyParam<N, C> param)
{
const int bi = blockIdx.x;
const int si = param.src_idx[bi];
const int di = param.dst_idx[bi];
for (int i = threadIdx.x; i < param.max_stride; i += blockDim.x) {
PRAGMA_UNROLL
for (int k = 0; k < N; ++k) {
if (i < param.stride[k]) {
*((T*)param.dst_ptr[k] + param.stride[k] * di + i) =
*((const T*)param.src_ptr[k] + param.stride[k] * si + i);
}
}
}
}
template<class T, int N>
void invokeIndexedCopyImpl(void** h_src_ptr,
void** h_dst_ptr,
const int* h_elem_sz,
const int* h_src_idx,
const int* h_dst_idx,
int count,
cudaStream_t st)
{
auto invoke = [&](auto max_count) {
constexpr int C = decltype(max_count)::value;
// maximum parameter size: sm<70: 4kB, sm>=70: 32kB
static_assert(sizeof(IndexedCopyParam<N, C>) <= 4096);
IndexedCopyParam<N, C> param{};
std::copy_n(h_src_ptr, N, param.src_ptr.data());
std::copy_n(h_dst_ptr, N, param.dst_ptr.data());
std::transform(h_elem_sz, h_elem_sz + N, param.stride.data(), [](int size) {
// Basic alignment check
FT_CHECK_WITH_INFO(size % sizeof(T) == 0, fmtstr("misalignment: %d %% %d", size, (int)sizeof(T)));
return size / sizeof(T);
});
param.max_stride = *std::max_element(param.stride.begin(), param.stride.end());
auto copy_idx = [](const int* src, int offset, int n, auto dst) {
return src ? (void)std::copy_n(src + offset, n, dst) : std::iota(dst, dst + n, offset);
};
for (int c = 0; c < count; c += C) {
int batch_size = std::min(count - c, C);
copy_idx(h_src_idx, c, batch_size, param.src_idx.data());
copy_idx(h_dst_idx, c, batch_size, param.dst_idx.data());
indexedCopy<T><<<batch_size, 128, 0, st>>>(param);
}
};
if (count <= 4) {
invoke(std::integral_constant<int, 4>{});
}
    else if (count <= 8) {
invoke(std::integral_constant<int, 8>{});
}
else if (count <= 16) {
invoke(std::integral_constant<int, 16>{});
}
else if (count <= 32) {
invoke(std::integral_constant<int, 32>{});
}
else if (count <= 64) {
invoke(std::integral_constant<int, 64>{});
}
else if (count <= 128) {
invoke(std::integral_constant<int, 128>{});
}
else {
invoke(std::integral_constant<int, 256>{});
}
}
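The invoke lambda above is instantiated through std::integral_constant tags so that the runtime count selects the smallest compile-time capacity C that still fits the kernel-parameter size limit. The same dispatch pattern in isolation (a generic sketch, not TurboMind code):

    #include <cstdio>
    #include <type_traits>

    // Pick the smallest compile-time capacity that can hold the runtime count and
    // pass it to the callable as a type-level constant.
    template<class F>
    void dispatch_capacity(int count, F&& f)
    {
        if (count <= 4) {
            f(std::integral_constant<int, 4>{});
        }
        else if (count <= 8) {
            f(std::integral_constant<int, 8>{});
        }
        else {
            f(std::integral_constant<int, 16>{});
        }
    }

    int main()
    {
        dispatch_capacity(6, [](auto cap) {
            constexpr int C = decltype(cap)::value;  // usable as a template argument or array extent
            int buf[C]{};
            std::printf("capacity = %d, first = %d\n", C, buf[0]);
        });
    }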
void invokeIndexedCopy(void** h_src_ptr,
void** h_dst_ptr,
const int* h_elem_sz,
const int* h_src_idx,
const int* h_dst_idx,
int count,
int n_copys,
cudaStream_t st)
{
auto args = std::tuple{h_src_ptr, h_dst_ptr, h_elem_sz, h_src_idx, h_dst_idx, count, st};
switch (n_copys) {
case 1:
return std::apply(invokeIndexedCopyImpl<uint32_t, 1>, args);
case 2:
return std::apply(invokeIndexedCopyImpl<uint32_t, 2>, args);
case 3:
return std::apply(invokeIndexedCopyImpl<uint32_t, 3>, args);
case 4:
return std::apply(invokeIndexedCopyImpl<uint32_t, 4>, args);
default:
FT_CHECK(0);
}
}
__global__ void padLastTokenIds(int* token_ids, const int* context_length, int max_context_len, int batch_size)
{
for (int bi = threadIdx.x; bi < batch_size; bi += blockDim.x) {
token_ids[(max_context_len - 1) * batch_size + bi] = token_ids[(context_length[bi] - 1) * batch_size + bi];
}
}
void invokePadLastTokenIds(
int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream)
{
padLastTokenIds<<<1, 512, 0, stream>>>(token_ids, context_length, max_context_len, batch_size);
}
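padLastTokenIds indexes token_ids in the [step, batch] layout used by token_ids_buf_: it copies each sequence's last valid token up to row max_context_len - 1 so that every sequence can resume decoding from a common step. A plain-CPU sketch of the same operation (illustration only, not part of the commit):

    // token_ids is laid out [step, batch]; context_length[bi] is the number of
    // valid tokens already held by sequence bi.
    void pad_last_token_ids(int* token_ids, const int* context_length, int max_context_len, int batch_size)
    {
        for (int bi = 0; bi < batch_size; ++bi) {
            token_ids[(max_context_len - 1) * batch_size + bi] =
                token_ids[(context_length[bi] - 1) * batch_size + bi];
        }
    }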
#define VERSION_SWITCH(VERSION, CONST_NAME, ...) \
    [&] { \
        if (VERSION == 2) { \
......