Unverified Commit 911c0a85 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

Optimize for throughput (#701)



* tmp

* update

* update

* optimize for throughput

* update

* fix eos

* clean up

* fix serving

* fix indexed copy

* minor

* minor

---------
Co-authored-by: default avatarlvhan028 <lvhan_028@163.com>
parent 65d735ba
......@@ -448,8 +448,8 @@ int main(int argc, char* argv[])
std::vector<int> hBuf(outCount);
ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
ft::cudaAutoCpy(hBuf.data(), d_output_ids, outCount);
ft::cudaAutoCpy(seq_lens.data(), d_seq_lens, batch_size);
std::cout << "sequence length: ";
for (int i = 0; i < batch_size; ++i) {
......
......@@ -350,7 +350,7 @@ class TurboMindInstance:
outputs = _tm_dict_to_torch_dict(tm_outputs)
output_ids = outputs['output_ids'][:, 0, :]
sequence_length = outputs['sequence_length'].long()[:, 0].cpu()
sequence_length = outputs['sequence_length'].long()[:, 0]
output_ids = [
output_id[s:l] for output_id, s, l in zip(
output_ids, seq_start, sequence_length)
......@@ -366,7 +366,6 @@ class TurboMindInstance:
outputs.append((output[:-1], len_))
else:
outputs.append((output, len_))
yield outputs
if finish:
......
......@@ -236,15 +236,88 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
template<typename T, int N>
struct Array {
T a[N];
__device__ __host__ constexpr T& operator[](int i) noexcept
using value_type = T;
using size_type = int;
using difference_type = int;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using iterator = pointer;
using const_iterator = const_pointer;
static_assert(N > 0);
T __a[N];
__device__ __host__ constexpr reference operator[](size_type i) noexcept
{
return __a[i];
}
__device__ __host__ constexpr const_reference operator[](size_type i) const noexcept
{
return __a[i];
}
__device__ __host__ constexpr reference front() noexcept
{
return *begin();
}
__device__ __host__ constexpr const_reference front() const noexcept
{
return *begin();
}
__device__ __host__ constexpr reference back() noexcept
{
return *(end() - 1);
}
__device__ __host__ constexpr const_reference back() const noexcept
{
return *(end() - 1);
}
__device__ __host__ constexpr pointer data() noexcept
{
return a[i];
return &__a[0];
}
__device__ __host__ constexpr const T& operator[](int i) const noexcept
__device__ __host__ constexpr const_pointer data() const noexcept
{
return &__a[0];
}
__device__ __host__ constexpr iterator begin() noexcept
{
return data();
}
__device__ __host__ constexpr const_iterator begin() const noexcept
{
return data();
}
__device__ __host__ constexpr iterator end() noexcept
{
return data() + N;
}
__device__ __host__ constexpr const_iterator end() const noexcept
{
return data() + N;
}
__device__ __host__ constexpr std::integral_constant<int, N> size() const noexcept
{
return {};
}
__device__ __host__ constexpr std::false_type empty() const noexcept
{
return a[i];
return {};
}
};
......
......@@ -188,6 +188,7 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
*
* output_tensors:
* \param output_ids [max_seq_len, batch_size]
* \param curand_state [local_batch_size]
* \param finished [batch_size * beam_width], optional
* \param should_stop [1] on cpu
* \param cum_log_probs [batch_size * beam_width], necessary in beam search
......@@ -276,7 +277,8 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
{"input_lengths", input_lengths.slice({local_batch_size, beam_width}, local_batch_offset)});
}
TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}});
TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}, //
{"curand_state", output_tensors->at("curand_state")}});
if (output_tensors->isExist("sequence_length")) {
Tensor sequence_length = output_tensors->at("sequence_length");
decode_output_tensors.insert(
......
......@@ -53,15 +53,6 @@ protected:
int* h_pinned_finished_sum_ = nullptr;
public:
curandState_t* topk_curandstate_buf()
{
return static_cast<BaseSamplingLayer<T>*>(topk_decode_)->curandstate_buf();
}
curandState_t* topp_curandstate_buf()
{
return static_cast<BaseSamplingLayer<T>*>(topp_decode_)->curandstate_buf();
}
DynamicDecodeLayer(size_t vocab_size,
size_t vocab_size_padded,
int end_id,
......
......@@ -30,10 +30,6 @@ template<typename T>
void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
curandstate_buf_ = reinterpret_cast<curandState_t*>(
allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false));
random_seeds_buf_ = reinterpret_cast<unsigned long long*>(
allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false));
temperature_buf_ =
reinterpret_cast<float*>(allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false));
repetition_penalty_buf_ =
......@@ -58,8 +54,6 @@ void BaseSamplingLayer<T>::freeBuffer()
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (is_allocate_buffer_) {
allocator_->free((void**)(&curandstate_buf_));
allocator_->free((void**)(&random_seeds_buf_));
allocator_->free((void**)(&temperature_buf_));
allocator_->free((void**)(&repetition_penalty_buf_));
allocator_->free((void**)(&min_lengths_buf_));
......@@ -128,32 +122,6 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
Tensor runtime_top_p = runtime_args->isExist("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor();
allocateBuffer(batch_size, runtime_top_k, runtime_top_p);
// If runtime argument has single random seed, using this random seed to initialize the random table of all
// sentences. If the argument has [batch_size] random seeds, initializing the random table by different random seeds
// respectively. If no random seed, initialize the random table of all sentences by 0 directly.
if (runtime_args->isExist("random_seed")) {
Tensor random_seeds = runtime_args->at("random_seed");
FT_CHECK_WITH_INFO(random_seeds.shape.size() == 1
&& (random_seeds.size() == 1 || random_seeds.size() == batch_size),
fmtstr("random_seeds must be of shape [1] or [batch_size(%ld)], got random_seeds.shape=%s",
batch_size,
vec2str(random_seeds.shape).c_str()));
if (random_seeds.size() == 1) {
invokeCurandInitialize(curandstate_buf_, batch_size, random_seeds.getVal<unsigned long long>(), stream_);
sync_check_cuda_error();
}
else {
unsigned long long* random_seed_ptr = random_seeds.getPtr<unsigned long long>();
cudaAutoCpy(random_seeds_buf_, random_seed_ptr, batch_size, stream_);
invokeCurandBatchInitialize(curandstate_buf_, batch_size, random_seeds_buf_, stream_);
sync_check_cuda_error();
}
}
else {
// Initialize curand states using the default seed 0.
invokeCurandInitialize(curandstate_buf_, batch_size, 0, stream_);
}
// Setup penalties.
const float default_temperature = 1.0f;
Tensor temperature = runtime_args->isExist("temperature") ?
......
......@@ -33,10 +33,8 @@ protected:
size_t vocab_size_;
size_t vocab_size_padded_;
size_t sampling_workspace_size_;
void* sampling_workspace_ = nullptr;
curandState_t* curandstate_buf_ = nullptr;
unsigned long long* random_seeds_buf_ = nullptr;
size_t sampling_workspace_size_;
void* sampling_workspace_ = nullptr;
float* temperature_buf_ = nullptr;
float* repetition_penalty_buf_ = nullptr;
......@@ -59,11 +57,6 @@ protected:
virtual void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p);
public:
curandState_t* curandstate_buf()
{
return curandstate_buf_;
}
BaseSamplingLayer(size_t max_batch_size,
size_t vocab_size,
size_t vocab_size_padded,
......
......@@ -16,6 +16,7 @@
*/
#include <float.h>
#include <sstream>
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
......@@ -199,6 +200,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
// output_tensors:
// output_ids [max_seq_len, batch_size]
// curand_state [local_batch_size]
// finished [local_batch_size], optional
// sequence_length [local_batch_size], optional
// cum_log_probs [batch_size], must be float*, optional
......@@ -255,7 +257,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
output_tensors->at("finished", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<bool>(),
cum_log_probs,
output_log_probs,
curandstate_buf_ + ite * local_batch_size,
output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
(int)runtime_max_top_k_, // useless because runtime_top_k_buf_ is never nullptr. Keep for legacy.
(int*)(runtime_top_k_buf_ + ite * local_batch_size),
1.0f, // useless because runtime_top_p_buf_ is never nullptr. Keep for legacy.
......
......@@ -40,8 +40,6 @@ private:
using BaseSamplingLayer<T>::sampling_workspace_size_;
using BaseSamplingLayer<T>::sampling_workspace_;
using BaseSamplingLayer<T>::curandstate_buf_;
using BaseSamplingLayer<T>::random_seeds_buf_;
using BaseSamplingLayer<T>::skip_decode_buf_;
using BaseSamplingLayer<T>::skip_decode_;
using BaseSamplingLayer<T>::skip_any_;
......
......@@ -132,7 +132,7 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
topp_id_vals_buf_,
topp_offset_buf_,
begin_topp_offset_buf_,
curandstate_buf_,
nullptr, // not used when workspace is null
batch_size,
vocab_size_padded_,
nullptr,
......@@ -267,6 +267,7 @@ void TopPSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
* output_tensors:
* \param output_ids [max_seq_len, batch_size]
* \param curand_state [local_batch_size]
* \param finished [local_batch_size], optional
* \param sequence_length [local_batch_size], optional
* \param cum_log_probs [batch_size], must be float*, optional
......@@ -319,7 +320,7 @@ void TopPSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
topp_id_vals_buf_,
topp_offset_buf_,
begin_topp_offset_buf_,
curandstate_buf_ + ite * local_batch_size,
output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
local_batch_size,
vocab_size_padded_,
input_tensors->at("end_id").getPtr<int>(),
......
......@@ -48,8 +48,6 @@ private:
using BaseSamplingLayer<T>::sampling_workspace_size_;
using BaseSamplingLayer<T>::sampling_workspace_;
using BaseSamplingLayer<T>::curandstate_buf_;
using BaseSamplingLayer<T>::random_seeds_buf_;
using BaseSamplingLayer<T>::skip_decode_buf_;
using BaseSamplingLayer<T>::skip_decode_;
using BaseSamplingLayer<T>::skip_any_;
......
......@@ -2,6 +2,7 @@
#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h"
......@@ -16,12 +17,15 @@
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <iterator>
#include <mutex>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <utility>
namespace turbomind {
......@@ -142,7 +146,9 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
template<typename T>
auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
{
NvtxScope scope("stop_request");
std::vector<Signal> signals;
int count = 0;
for (const auto& r : requests) {
int ec = Request::kFail;
// find matching active sequence
......@@ -150,29 +156,25 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector
// stop & optionally erase active sequence
if (state_->requests[i] && state_->requests[i]->id == r->id) {
ec = 0;
CompleteRequest(i, true, r->end_flag);
state_->requests[i].reset();
signals.push_back(Interrupt(i, true, r->end_flag));
++count;
break;
}
}
// mismatch, try erase inactive sequence, in this case there is no active request to finish
// mismatch, try erase inactive sequence, in this case there is no active request to interrupt
if (ec && r->end_flag) {
ec = 0;
sequence_manager_->Erase(r->id);
if (sequence_manager_->Erase(r->id)) {
ec = 0;
}
}
// clear output buffers (prevent leaking conversations) if request is successful
if (ec == 0) {
signals.push_back([=] {
if (rank_ == 0) {
std::unique_lock lock{output_mutex_};
output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
r->signal.set_value(ec);
}
auto& output_ids = r->outputs[rank_].at("output_ids");
auto& sequence_length = r->outputs[rank_].at("sequence_length");
Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
Clear(sequence_length.getPtr<int>(), 1);
check_cuda_error(cudaStreamSynchronize(stream_));
}
signals.push_back([=] { r->signal.set_value(ec); });
});
}
if (count) {
check_cuda_error(cudaStreamSynchronize(stream_));
}
return signals;
}
......@@ -180,25 +182,29 @@ auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector
template<typename T>
void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
{
auto& state = *incoming_;
NvtxScope scope("infer_request");
auto& state = *incoming_;
FT_CHECK(state.size == 0);
FT_CHECK(state.active_size == 0);
int i = 0;
for (const auto& r : requests) {
std::vector<int> existing_idx;
// sanity check, incoming request in previous iter should have been moved to `state_`
FT_CHECK(!state.requests[i]);
int idx = 0;
for (const auto& r : requests) {
FT_CHECK(!state.requests[idx]);
TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
if (rank_ == 0) {
TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
}
state.requests[i] = r;
state.requests[idx] = r;
// get sequence for the request
state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
state.sequences[idx] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
FT_CHECK(state.sequences[idx]);
auto& seq = *state.sequences[i];
auto& seq = *state.sequences[idx];
if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
/// TODO: revise step setting
......@@ -216,7 +222,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
const int* input_ids = r->inputs[rank_].getPtr<int>("input_ids");
// `output_ids` contains all token ids of the sequences
const auto output_ids_base = state.output_ids + session_len_ * i;
const auto output_ids_base = state.output_ids + session_len_ * idx;
auto output_ids = output_ids_base;
// copy history tokens
......@@ -230,21 +236,21 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
}
// total context length (history + input)
state.h_context_length[i] = output_ids - output_ids_base;
state.h_finished[i] = false;
state.h_context_length[idx] = output_ids - output_ids_base;
state.h_finished[idx] = false;
const int request_output_len = state.requests[i]->inputs[rank_].getVal<int>("request_output_len");
state.seq_len_limit[i] = state.h_context_length[i] + request_output_len;
const int request_output_len = state.requests[idx]->inputs[rank_].getVal<int>("request_output_len");
state.seq_len_limit[idx] = state.h_context_length[idx] + request_output_len;
// `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
// the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
if (state.seq_len_limit[i] >= session_len_) {
state.seq_len_limit[i] = session_len_ - 1;
if (state.seq_len_limit[idx] >= session_len_) {
state.seq_len_limit[idx] = session_len_ - 1;
if (rank_ == 0) {
const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i];
const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
TM_LOG_WARNING(
"[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
(long)seq.id,
state.h_context_length[i],
state.h_context_length[idx],
request_output_len,
(int)session_len_,
trunc_output_len);
......@@ -260,7 +266,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
}
else if (model_->attn_params_.rope_scaling_factor >= 1.f) { // infer by `seq_len_limit`
scaling_factor = model_->attn_params_.rope_scaling_factor;
auto max_seq_len = state.seq_len_limit[i];
auto max_seq_len = state.seq_len_limit[idx];
auto max_pos_emb = model_->attn_params_.max_position_embeddings;
if (max_seq_len > max_pos_emb) {
scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
......@@ -277,22 +283,45 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
seq.rope_theta);
}
}
state.h_rope_theta[i] = seq.rope_theta;
state.h_rope_theta[idx] = seq.rope_theta;
// recover device states if not a new sequence
if (!r->start_flag) {
Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state);
Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state);
if (r->start_flag) {
// prepare to initialize random state for new sequence
h_random_seed_[idx] = r->inputs[rank_].getVal<unsigned long long>("random_seed", 0);
}
else {
// Recover device states if not a new sequence
h_curand_state_[existing_idx.size()] = *(curandState_t*)seq.random_state.data();
existing_idx.push_back(idx);
}
// ! SHARED STATE IS MODIFIED, BARRIER SYNCHRONIZATION REQUIRED
// assign priority based on arrival time
r->priority = request_count_++;
if (rank_ == 0) {
r->priority = request_count_++;
}
// increment pointer
i++;
idx++;
}
incoming_->size = i;
state.size = idx;
// when there are new sequences
if (state.size != existing_idx.size()) {
// copy random seeds to device
Copy(h_random_seed_, state.size, d_random_seed_);
// initialize random states
invokeCurandBatchInitialize(state.curand_state, state.size, d_random_seed_, stream_);
sync_check_cuda_error();
}
if (!existing_idx.empty()) {
// copy existing curand states to device
Copy(h_curand_state_, existing_idx.size(), d_curand_state_);
// insert the states to their correct positions in the batch
IndexedCopy({}, existing_idx, std::tuple{d_curand_state_, state.curand_state, 1});
}
}
template<typename T>
......@@ -375,14 +404,11 @@ bool LlamaBatch<T>::Initialize()
// Copy sequence states to back buffer
FT_CHECK(back_->size == 0 && back_->active_size == 0);
std::vector<std::tuple<BatchState*, BatchState*, int, int>> cpys;
for (const auto& i : idxs) {
auto& s = *sequences[i];
if (exchange) {
const auto& [state, idx] = coords[i];
// backup random states from dynamic decode layers for swap-outs
if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
SaveRandomState(*state, idx);
}
// mark swap-ins
if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
state->is_swap_in[idx] = 1;
......@@ -391,8 +417,9 @@ bool LlamaBatch<T>::Initialize()
if (s.status == Sequence::kActive) {
++back_->active_size;
}
CopyState(coords[i], {back_, back_->size++});
cpys.emplace_back(coords[i].first, back_, coords[i].second, back_->size++);
}
CopyState(cpys);
// Swap the buffers
std::swap(state_, back_);
......@@ -421,6 +448,9 @@ bool LlamaBatch<T>::Initialize()
// cumulative num of blocks
h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(),
std::to_string(h_cu_block_counts_[i + 1]));
k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
});
......@@ -444,41 +474,60 @@ bool LlamaBatch<T>::Initialize()
}
template<typename T>
void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst)
void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
{
const auto& [src, i] = _src;
const auto& [dst, j] = _dst;
std::vector<int> idxs(desc.size());
std::iota(idxs.begin(), idxs.end(), 0);
FT_CHECK((bool)src->requests[i]);
FT_CHECK(!(bool)dst->requests[j]);
std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return desc[i] < desc[j]; });
dst->h_context_length[j] = src->h_context_length[i];
dst->h_finished[j] = src->h_finished[i];
dst->h_rope_theta[j] = src->h_rope_theta[i];
dst->seq_len_limit[j] = src->seq_len_limit[i];
dst->sequences[j] = src->sequences[i];
dst->is_swap_in[j] = src->is_swap_in[i];
dst->requests[j] = std::move(src->requests[i]);
auto get_signature = [&](int i) -> std::pair<BatchState*, BatchState*> {
return std::make_pair(std::get<0>(desc[idxs[i]]), std::get<1>(desc[idxs[i]]));
};
Copy(src->output_ids + i * session_len_, src->h_context_length[i], dst->output_ids + j * session_len_);
std::vector<int> offsets;
auto current = get_signature(0);
offsets.push_back(0);
for (int i = 0; i < idxs.size(); ++i) {
if (auto signature = get_signature(i); signature != current) {
current = signature;
offsets.push_back(i);
}
}
offsets.push_back(idxs.size());
Copy((curandState_t*)src->top_k_curand_state + i, 1, (curandState_t*)dst->top_k_curand_state + j);
Copy((curandState_t*)src->top_p_curand_state + i, 1, (curandState_t*)dst->top_p_curand_state + j);
}
for (int bi = 1; bi < offsets.size(); ++bi) {
int beg = offsets[bi - 1];
int end = offsets[bi];
template<typename T>
void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
{
Copy(model_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
Copy(model_->GetTopPState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
}
if (beg == end) {
continue;
}
template<typename T>
void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
{
dbg(idx);
Copy((curandState_t*)state.top_k_curand_state + idx, 1, model_->GetTopKState(idx));
Copy((curandState_t*)state.top_p_curand_state + idx, 1, model_->GetTopPState(idx));
auto [s, d] = get_signature(beg);
std::vector<int> s_idx;
std::vector<int> d_idx;
for (int i = beg; i < end; ++i) {
s_idx.push_back(std::get<2>(desc[idxs[i]]));
d_idx.push_back(std::get<3>(desc[idxs[i]]));
}
IndexedCopy(s_idx,
d_idx,
std::tuple{s->output_ids, d->output_ids, session_len_},
std::tuple{s->curand_state, d->curand_state, 1});
}
for (const auto& [s, d, si, di] : desc) {
d->h_context_length[di] = s->h_context_length[si];
d->h_finished[di] = s->h_finished[si];
d->h_rope_theta[di] = s->h_rope_theta[si];
d->seq_len_limit[di] = s->seq_len_limit[si];
d->sequences[di] = s->sequences[si];
d->is_swap_in[di] = s->is_swap_in[si];
d->requests[di] = s->requests[si];
}
}
template<typename T>
......@@ -527,7 +576,6 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);
end_ids_buf_ = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false);
finished_buf_ = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
......@@ -543,30 +591,45 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
template<typename T>
void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
{
stop_words_buf_ =
(int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
bad_words_buf_ =
(int*)allocator_->reMalloc(bad_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
d_stop_words_ = (int*)allocator_->reMalloc(d_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
d_bad_words_ = (int*)allocator_->reMalloc(d_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
h_stop_words_ =
(int*)allocator_->reMalloc(h_stop_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
h_bad_words_ =
(int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
h_repetition_penalty_ =
(float*)allocator_->reMalloc(h_repetition_penalty_, sizeof(float) * max_batch_size, true, true);
h_random_seed_ = (uint64_t*)allocator_->reMalloc(h_random_seed_, sizeof(uint64_t) * max_batch_size, true, true);
sampling_params_ = {{"stop_words_list", stop_words_buf_},
{"bad_words_list", bad_words_buf_},
{"runtime_top_k", h_runtime_top_k_},
{"runtime_top_p", h_runtime_top_p_},
{"temperature", h_temperature_},
{"repetition_penalty", h_repetition_penalty_},
{"random_seed", h_random_seed_}};
h_random_seed_ = (unsigned long long*)allocator_->reMalloc(
h_random_seed_, sizeof(unsigned long long) * max_batch_size, true, true);
d_random_seed_ = (unsigned long long*)allocator_->reMalloc(
d_random_seed_, sizeof(unsigned long long) * max_batch_size, true, false);
h_curand_state_ =
(curandState_t*)allocator_->reMalloc(h_curand_state_, sizeof(curandState_t) * max_batch_size, true, true);
d_curand_state_ =
(curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);
d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
sampling_params_ = {
{"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
{"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
{"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
{"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
{"temperature", (std::byte*)h_temperature_, nullptr},
{"repetition_penalty", (std::byte*)h_repetition_penalty_, nullptr},
};
for (auto& s : states_) {
s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
s.top_k_curand_state = allocator_->reMalloc(s.top_k_curand_state, sizeof(curandState_t) * max_batch_size, true);
s.top_p_curand_state = allocator_->reMalloc(s.top_p_curand_state, sizeof(curandState_t) * max_batch_size, true);
s.curand_state =
(curandState_t*)allocator_->reMalloc(s.curand_state, sizeof(curandState_t) * max_batch_size, true);
}
const size_t max_block_count = sequence_manager_->max_block_count();
......@@ -604,6 +667,9 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
(int*)allocator_->reMalloc(h_request_output_ids_lens_, sizeof(int) * max_batch_size, true, true);
h_request_seqlen_ptrs_ =
(int**)allocator_->reMalloc(h_request_seqlen_ptrs_, sizeof(int*) * max_batch_size, true, true);
h_output_ids_ =
(int*)allocator_->reMalloc(h_output_ids_, sizeof(int) * max_batch_size * session_len_, false, true);
}
is_allocate_persistant_buffer_ = true;
......@@ -648,7 +714,9 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&token_ids_buf_);
allocator_->free((void**)&end_ids_buf_);
allocator_->free((void**)&d_end_ids_buf_);
allocator_->free((void**)&h_end_ids_buf_, true);
allocator_->free((void**)&finished_buf_);
allocator_->free((void**)&seq_limit_len_);
......@@ -662,11 +730,22 @@ void LlamaBatch<T>::FreeBuffer()
}
if (is_allocate_persistant_buffer_) {
allocator_->free((void**)&d_stop_words_);
allocator_->free((void**)&h_stop_words_, true);
allocator_->free((void**)&d_bad_words_);
allocator_->free((void**)&h_bad_words_, true);
allocator_->free((void**)&d_random_seed_);
allocator_->free((void**)&h_random_seed_, true);
allocator_->free((void**)&d_curand_state_);
allocator_->free((void**)&h_curand_state_, true);
for (auto& s : states_) {
allocator_->free((void**)&s.h_context_length, true);
allocator_->free((void**)&s.h_finished, true);
allocator_->free((void**)&s.h_rope_theta, true);
allocator_->free((void**)&s.output_ids);
allocator_->free((void**)&s.curand_state);
}
allocator_->free((void**)&h_tmp_k_ptrs_, true);
allocator_->free((void**)&h_tmp_v_ptrs_, true);
......@@ -681,6 +760,8 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&h_request_output_ids_lens_, true);
allocator_->free((void**)&h_request_seqlen_ptrs_, true);
allocator_->free((void**)&h_output_ids_, true);
is_allocate_persistant_buffer_ = false;
}
}
......@@ -723,14 +804,15 @@ LlamaBatch<T>::LlamaBatch(int max_batch_size,
template<typename T>
void LlamaBatch<T>::InitializeSampling()
{
NvtxScope _("InitSampling");
const int batch_size = state_->active_size;
TensorMap inputs;
for (const auto& param : sampling_params_) {
for (const auto& [name, h_ptr, d_ptr] : sampling_params_) {
// find an exemplar that matches the param name
const Tensor* ptr{};
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
ptr = &state_->requests[i]->inputs[rank_].at(param.first);
if (state_->requests[i]->inputs[rank_].isExist(name)) {
ptr = &state_->requests[i]->inputs[rank_].at(name);
break;
}
}
......@@ -741,42 +823,50 @@ void LlamaBatch<T>::InitializeSampling()
FT_CHECK(shape[0] == 1);
shape[0] = batch_size;
const int size_in_bytes = ref.sizeBytes();
Clear((std::byte*)param.second, size_in_bytes * batch_size);
memset(h_ptr, 0, size_in_bytes * batch_size);
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
auto& src = state_->requests[i]->inputs[rank_].at(param.first);
if (state_->requests[i]->inputs[rank_].isExist(name)) {
Tensor& src = state_->requests[i]->inputs[rank_].at(name);
FT_CHECK(ref.shape == src.shape);
Copy(src.getPtr<std::byte>(), size_in_bytes, (std::byte*)param.second + size_in_bytes * i);
std::copy_n(src.getPtr<std::byte>(), size_in_bytes, h_ptr + size_in_bytes * i);
}
}
inputs.insert({param.first, {ref.where, ref.type, shape, param.second}});
if (d_ptr) {
Copy(h_ptr, batch_size * size_in_bytes, d_ptr);
}
inputs.insert({name, {d_ptr ? MEMORY_GPU : MEMORY_CPU, ref.type, shape, d_ptr ? d_ptr : h_ptr}});
if (debug_ && rank_ == 0) {
TM_LOG_INFO("[initializeSampling] %s", format({param.first, inputs.at(param.first)}).c_str());
TM_LOG_INFO("[initializeSampling] %s", format({name, inputs.at(name)}).c_str());
}
}
}
// init for eos
std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)batch_size}, d_end_ids_buf_}});
inputs_ = std::move(inputs);
model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
// recover random states if not a new request
for (int i = 0; i < batch_size; ++i) {
if (!state_->requests[i]->start_flag && state_->is_swap_in[i]) {
LoadRandomState(*state_, i);
}
}
handleOptArg(&inputs_, "end_id", end_ids_buf_, model_->end_id_, batch_size);
cudaStreamSynchronize(0);
}
template<typename T>
auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
{
NvtxScope _("InitGen");
const int batch_size = state_->active_size;
const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
Copy(state_->h_context_length, batch_size, context_length_buf_); // also referenced in `SetOutputTensors`
Copy(context_length_buf_, batch_size, sequence_lengths_);
// `sequence_lengths_` will be increased by dynamic decode
// note that in decoder and in output "sequence length" has different semantic
// - in decoder it means length of sequence that has kv cache already computed
// - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
sync_check_cuda_error();
Clear(token_ids_buf_, batch_size * session_len_);
invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
sync_check_cuda_error();
......@@ -787,22 +877,7 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
// ABCDEFGHi -> ABCDEFGHi i
// ABCDEFGh ABCDEFGh h
// ABCd ABCd d
for (int i = 0; i < batch_size; ++i) {
auto token_ids = token_ids_buf_ + i;
auto p_src = state_->h_context_length[i] - 1;
auto p_dst = max_context_len - 1;
if (p_src != p_dst) { // dst and src of `cudaMemcpyAsync` must not overlap
Copy(token_ids + p_src * batch_size, 1, token_ids + p_dst * batch_size);
}
}
Copy(state_->h_context_length, batch_size, context_length_buf_); // also referenced in `SetOutputTensors`
Copy(context_length_buf_, batch_size, sequence_lengths_);
// `sequence_lengths_` will be increased by dynamic decode
// note that in decoder and in output "sequence length" has different semantic
// - in decoder it means length of sequence that has kv cache already computed
// - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
invokePadLastTokenIds(token_ids_buf_, context_length_buf_, max_context_len, batch_size, stream_);
sync_check_cuda_error();
// used for dispatching split-k decoding kernels
......@@ -846,21 +921,19 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len);
TM_LOG_INFO("[initGen] slot sequence_id context_len seq_limit_len finished");
for (int i = 0; i < batch_size; ++i) {
TM_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d",
i,
(long)state_->sequences[i]->id,
state_->h_context_length[i],
(int)h_seq_limit_len_[i],
(int)state_->h_finished[i]);
if (debug_) {
TM_LOG_INFO("[initGen] slot sequence_id context_len seq_limit_len finished");
for (int i = 0; i < batch_size; ++i) {
TM_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d",
i,
(long)state_->sequences[i]->id,
state_->h_context_length[i],
(int)h_seq_limit_len_[i],
(int)state_->h_finished[i]);
}
}
}
// for (int i = 0; i < batch_size; ++i) {
// gSequenceIds(i) = state_->requests[i]->id;
// }
return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
}
......@@ -908,6 +981,9 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
decoder_output_buf_,
batch_size);
/// sync for better NVTX visualization, THIS IS NOT NEEDED
// check_cuda_error(cudaStreamSynchronize(stream_));
// stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
// not supported yet.
bool should_stop{};
......@@ -915,12 +991,13 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
finished_buf_,
sequence_lengths_,
&should_stop,
state_->curand_state,
&inputs_,
&outputs_,
logits_buf_,
seq_limit_len_,
context_length_buf_,
end_ids_buf_,
d_end_ids_buf_,
g.step,
0,
g.max_init_ctx_len,
......@@ -960,6 +1037,7 @@ bool LlamaBatch<T>::Generate(GenerationState& g)
template<typename T>
void LlamaBatch<T>::ContextDecode()
{
NvtxScope _("prefill");
const auto batch_size = state_->active_size;
int base = -1;
......@@ -987,8 +1065,8 @@ void LlamaBatch<T>::ContextDecode()
Copy(state_->h_rope_theta, batch_size, rope_theta_);
Copy(h_input_length_buf_, batch_size, input_length_buf_);
check_cuda_error(cudaStreamSynchronize(stream_));
const auto tick = std::chrono::high_resolution_clock::now();
// check_cuda_error(cudaStreamSynchronize(stream_));
// const auto tick = std::chrono::high_resolution_clock::now();
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
......@@ -1039,7 +1117,7 @@ void LlamaBatch<T>::ContextDecode()
auto input_ids = context_decoder_ids_buf_;
TM_LOG_INFO("first = %d, last = %d", first, last);
for (int i = first; i < last; ++i) {
TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
// TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
dbg(i, h_input_length_buf_[i]);
h_tmp_k_ptrs_[i] = k_ptr;
......@@ -1069,15 +1147,6 @@ void LlamaBatch<T>::ContextDecode()
dbg(first, last);
dbg(k_block_ptrs_, v_block_ptrs_);
if (1) {
std::vector<int> input_len(sub_batch_size);
std::vector<int> context_len(sub_batch_size);
Copy(input_length_buf_ + first, sub_batch_size, input_len.data());
Copy(context_length_buf_ + first, sub_batch_size, context_len.data());
cudaStreamSynchronize(stream_);
dbg(input_len, context_len);
}
model_->contextDecode(nullptr,
k_block_ptrs_,
v_block_ptrs_,
......@@ -1112,11 +1181,11 @@ void LlamaBatch<T>::ContextDecode()
}
}
check_cuda_error(cudaStreamSynchronize(stream_));
const auto tock = std::chrono::high_resolution_clock::now();
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
}
// check_cuda_error(cudaStreamSynchronize(stream_));
// const auto tock = std::chrono::high_resolution_clock::now();
// if (rank_ == 0) {
// TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
// }
}
template<typename T>
......@@ -1167,65 +1236,43 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
}
template<typename T>
auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
{
NvtxScope scope("Finish");
const int batch_size = state_->active_size;
// secure info needed by `Initialize()`
Copy(finished_buf_, batch_size, state_->h_finished);
// [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
invokeGatherOutput(state_->output_ids,
token_ids_buf_,
context_length_buf_,
g.max_init_ctx_len,
g.step,
session_len_,
batch_size,
stream_);
sync_check_cuda_error();
// invariant: context_length = sequence_length + 1
invokePlusScalar(sequence_lengths_, 1, batch_size, stream_);
Copy(state_->output_ids, batch_size * session_len_, h_output_ids_);
Copy(finished_buf_, batch_size, state_->h_finished);
Copy(sequence_lengths_, batch_size, state_->h_context_length);
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
if constexpr (0) {
std::unique_lock<std::mutex> lock;
if (rank_ == 0) {
NvtxScope _("acquire_outputs");
// wait for previous output operations
lock = std::unique_lock{output_mutex_};
output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
}
SetOutputTensors(g);
check_cuda_error(cudaStreamSynchronize(stream_));
check_cuda_error(cudaStreamSynchronize(stream_));
if (rank_ == 0) {
NvtxScope _("signal_output_thread");
// enqueue new output requests
for (int i = 0; i < batch_size; ++i) {
FT_CHECK(state_->requests[i] != nullptr);
if (state_->requests[i]->stream_cb) {
output_reqs_.push_back(state_->requests[i]);
}
}
lock.unlock();
// notify output thread when we do have stream cbs to call
if (!output_reqs_.empty()) {
output_cv_.notify_one();
}
}
// invariant: context_length = sequence_length + 1
for (int i = 0; i < batch_size; ++i) {
++state_->h_context_length[i];
}
else {
SetOutputTensors(g);
check_cuda_error(cudaStreamSynchronize(stream_));
{
NvtxScope _("output_cb");
if (rank_ == 0 && model_->ffi_lock_) {
model_->ffi_lock_(1);
}
for (int i = 0; i < batch_size; ++i) {
FT_CHECK(state_->requests[i] != nullptr);
if (state_->requests[i]->stream_cb && rank_ == 0) {
state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
}
}
if (rank_ == 0 && model_->ffi_lock_) {
model_->ffi_lock_(0);
{ // set output tokens ids and sequence length
int* output_ptr = h_output_ids_;
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
// TODO: sync history output tokens at when receiving the request and copy only the last token here
std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
*h_request_seqlen_ptrs_[i] = count;
}
output_ptr += session_len_;
}
}
......@@ -1247,66 +1294,37 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
std::vector<Signal> signals;
{
NvtxScope _("prepare_completion_signal");
NvtxScope _("stream_and_completion_signal");
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i] && state_->h_finished[i]) {
CompleteRequest(i, false, false);
signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
if (state_->requests[i]) {
if (state_->h_finished[i]) {
// Interrupt finished sequences and move the request handle into the signal closure
signals.push_back(Interrupt(i));
++finished_count;
}
else if (state_->requests[i]->stream_cb) {
// Create signals by copying the request handles for non-finished streaming requests
signals.push_back([this, r = state_->requests[i]] {
if (rank_ == 0) {
r->stream_cb(&r->outputs[rank_].get());
}
});
}
}
}
if (finished_count) {
// synchronize for interrupted sequences
check_cuda_error(cudaStreamSynchronize(stream_));
}
}
return signals;
}
template<typename T>
void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
{
NvtxScope scope("SetOutputTensors");
// dbg(g.max_init_ctx_len);
const auto batch_size = state_->active_size;
// [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
invokeGatherOutput(state_->output_ids,
token_ids_buf_,
context_length_buf_,
g.max_init_ctx_len,
g.step,
session_len_,
batch_size,
stream_);
sync_check_cuda_error();
if constexpr (1) {
invokeUpdateOutput(request_output_ids_ptrs_,
request_seqlen_ptrs_,
state_->output_ids,
sequence_lengths_,
request_output_ids_lens_,
session_len_,
g.step > g.max_init_ctx_len,
batch_size,
stream_);
sync_check_cuda_error();
}
else {
// for (int i = 0; i < batch_size; ++i) {
// if (state_->requests[i]) {
// auto& output_ids = state_->requests[i]->outputs[rank_].at("output_ids");
// auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
// Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
// Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
// if (g.step > g.max_init_ctx_len) { // +1 for newly generated token
// invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
// }
// }
// }
}
}
template<typename T>
void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_force_end)
auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Signal
{
if (rank_ == 0) {
TM_LOG_INFO("[CompleteRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
TM_LOG_INFO("[Interrupt] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
}
if (debug_ && rank_ == 0) {
......@@ -1317,45 +1335,46 @@ void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_for
for (const auto& t : tokens) {
ss << " " << t;
}
TM_LOG_INFO("[CompleteRequest] slot %d, tokens [%s]", index, ss.str().c_str());
TM_LOG_INFO("[Interrupt] slot %d, tokens [%s]", index, ss.str().c_str());
}
if (state_->requests[index]->end_flag || is_force_end) {
sequence_manager_->Erase(state_->requests[index]->id);
if (state_->requests[index]->end_flag || force_end) {
// Sequence is ending this round or a stop request is issued to end it
FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
}
else {
// account for the last generated token if not a stop request (which doesn't generate)
const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(is_stop_request);
// Account for the last generated token if not a stop request (which doesn't generate)
const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
auto& seq = *state_->sequences[index];
auto& seq = *state_->sequences[index];
// update token IDs
// Update token IDs
seq.tokens.resize(output_len);
const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
Copy(output_ids_data, output_len, seq.tokens.data());
std::copy_n(output_ids_data, output_len, seq.tokens.data());
// update random states
seq.random_state.resize(sizeof(curandState_t) * 2);
// save random state in host memory
if (auto ptr = (curandState_t*)seq.random_state.data()) {
ptr = Copy(model_->GetTopKState(index), 1, ptr);
ptr = Copy(model_->GetTopPState(index), 1, ptr);
}
check_cuda_error(cudaStreamSynchronize(stream_));
// Save random state in host memory
seq.random_state.resize(sizeof(curandState_t));
// This async copy must be synchronized by the caller
Copy(state_->curand_state + index, 1, (curandState_t*)seq.random_state.data());
// Set unlock flag for corresponding blocks, will be unlocked in the next `Materialize()`
sequence_manager_->UpdateAndSetUnlock(seq);
}
state_->sequences[index] = nullptr;
// move the request handle into the signal
return [this, r = std::move(state_->requests[index])] {
if (rank_ == 0) {
r->signal.set_value(0);
}
};
}
template<typename T>
void LlamaBatch<T>::InternalThreadEntry(int device_id)
{
TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
// TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
check_cuda_error(cudaSetDevice(device_id));
auto& shared_state = model_->shared_state_;
......@@ -1364,20 +1383,26 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
auto& infer_requests = shared_state->infer_requests;
auto& stop_requests = shared_state->stop_requests;
// sequences that are removed but still counted in state's size
int finished_count = 0;
GenerationState g{};
constexpr int request_interval = 1;
long request_counter = 0;
while (1) {
if (rank_ == 0) {
const int free_slot_count = max_batch_size_ - state_->size + finished_count;
const bool is_empty = (free_slot_count == max_batch_size_);
// will block if batch is empty
request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
if (!shared_state->abort) {
RejectInvalidRequests(stop_requests, infer_requests);
stop_requests.clear();
infer_requests.clear();
if (is_empty || request_counter % request_interval == 0) {
// Block if batch is empty
request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
if (!shared_state->abort) {
RejectInvalidRequests(stop_requests, infer_requests);
}
}
}
......@@ -1388,20 +1413,19 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
if (shared_state->abort) {
TM_LOG_INFO("[InternalThreadEntry] stop requested.");
// if (state_->size && rank_ == 0) {
// TM_LOG_WARNING("Active request(s) present (%d) while exiting.", state_->size);
// }
return;
}
auto signals = ProcessStopRequests(stop_requests);
BarrierSignalRequests(*shared_state->barrier, signals);
// Shared `priority` field will be assigned by rank-0
ProcessInferRequests(infer_requests);
// wait while shared stop/infer_requests is being used
// Wait while shared `requests` is being used
shared_state->barrier->wait();
SendSignals(std::move(signals));
auto modified = Initialize();
// finished sequences is handled by `Initialize()`
finished_count = 0;
......@@ -1418,31 +1442,42 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
break;
}
}
auto signals = Finish(g);
finished_count = signals.size();
BarrierSignalRequests(*shared_state->barrier, signals);
if (auto signals = Finish(g, finished_count); !signals.empty()) {
if (finished_count) {
// Finished requests and corresponding output tensors will be released when notified
// wait for all ranks to ensure no rank (except for output thread) will access related
// resources
shared_state->barrier->wait();
}
SendSignals(std::move(signals));
}
}
++request_counter;
}
FT_CHECK(0);
}
template<typename T>
void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals)
void LlamaBatch<T>::SendSignals(std::vector<Signal> signals)
{
if (!signals.empty()) {
barrier.wait();
if (rank_ == 0) {
std::for_each(signals.cbegin(), signals.cend(), [](auto& s) { s(); });
}
barrier.wait();
if (rank_ != 0 || signals.empty()) {
return;
}
{
std::lock_guard lock{output_mutex_};
output_signals_.insert(output_signals_.end(), //
std::move_iterator{signals.begin()},
std::move_iterator{signals.end()});
}
output_cv_.notify_one();
}
template<typename T>
void LlamaBatch<T>::Start()
{
TM_LOG_ERROR("LlamaBatch<T>::Start()");
TM_LOG_INFO("LlamaBatch<T>::Start()");
int device_id = -1;
check_cuda_error(cudaGetDevice(&device_id));
internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
......@@ -1455,36 +1490,27 @@ template<typename T>
void LlamaBatch<T>::OutputThreadEntry()
{
while (true) {
std::vector<Signal> signals;
{
// wait for requests with stream cbs
// Wait for signals to come
std::unique_lock lock(output_mutex_);
output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
// NvtxScope _("output_callback");
// stop requested
output_cv_.wait(lock, [&] { return !output_signals_.empty() || output_stop_token_; });
if (output_stop_token_) {
TM_LOG_INFO("[OutputThreadEntry] stop requested.");
return;
}
if (rank_ == 0 && model_->ffi_lock_) {
TM_LOG_INFO("acquire GIL");
model_->ffi_lock_(1);
TM_LOG_INFO("acquire GIL success");
}
// invoke stream cbs
for (const auto& r : output_reqs_) {
r->stream_cb(&r->outputs[rank_].get());
}
if (rank_ == 0 && model_->ffi_lock_) {
TM_LOG_INFO("release GIL");
model_->ffi_lock_(0);
TM_LOG_INFO("release GIL success");
}
output_reqs_.clear();
signals = std::move(output_signals_);
}
if (rank_ == 0 && model_->ffi_lock_) {
model_->ffi_lock_(1);
}
// invoke stream cbs & signals
for (const auto& s : signals) {
s();
}
if (rank_ == 0 && model_->ffi_lock_) {
model_->ffi_lock_(0);
}
FT_CHECK(output_reqs_.empty());
// notify infer thread 0
output_cv_.notify_one();
}
}
......
......@@ -3,14 +3,18 @@
#pragma once
// #include "src/turbomind/models/llama/LlamaCacheManager.h"
#include "src/turbomind/layers/sampling_layers/BaseSamplingLayer.h"
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include <condition_variable>
#include <mutex>
#include <type_traits>
namespace turbomind {
......@@ -18,9 +22,8 @@ struct BatchState {
int* h_context_length;
bool* h_finished;
void* top_k_curand_state;
void* top_p_curand_state;
int* output_ids; // output ids in [B, S]
curandState_t* curand_state;
int* output_ids; // output ids in [B, S]
float* h_rope_theta;
......@@ -66,16 +69,15 @@ public:
int max_seq_len;
};
void InitializeSampling();
void InitializeSampling();
GenerationState InitializeGeneration();
[[nodiscard]] bool Generate(GenerationState& g);
[[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
void CompleteRequest(int index, bool is_stop_request, bool is_force_end);
[[nodiscard]] auto Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>;
void SetOutputTensors(const GenerationState& g);
[[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
void
OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
......@@ -88,7 +90,7 @@ public:
~LlamaBatch()
{
TM_LOG_ERROR("~LlamaBatch()");
TM_LOG_INFO("~LlamaBatch()");
model_->shared_state_->request_queue.close();
internal_thread_.join();
......@@ -112,15 +114,9 @@ private:
void OutputThreadEntry();
void UpdateSequenceStates(BatchState& state, int index);
void CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst);
void SaveRandomState(BatchState& state, int idx);
void LoadRandomState(BatchState& state, int idx);
void CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc);
void BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals);
void SendSignals(std::vector<Signal> signals);
// analogs to `std::copy_n`
template<typename U>
......@@ -137,6 +133,47 @@ private:
return data += count;
}
template<class... Ts>
void IndexedCopyImpl(const int* src_idx, const int* dst_idx, int count, const std::tuple<Ts*, Ts*, int>&... cpys)
{
if (!count) {
return;
}
constexpr int N = sizeof...(Ts);
static_assert((!std::is_same_v<Ts, void> && ...));
std::array<void*, N> src_ptr{std::get<0>(cpys)...};
std::array<void*, N> dst_ptr{std::get<1>(cpys)...};
std::array<int, N> elem_sz{int(sizeof(Ts) * std::get<2>(cpys))...};
invokeIndexedCopy(src_ptr.data(), //
dst_ptr.data(),
elem_sz.data(),
src_idx,
dst_idx,
count,
N,
stream_);
sync_check_cuda_error();
}
template<class... Ts>
void IndexedCopy(const std::vector<int>& src_idx,
const std::vector<int>& dst_idx,
const std::tuple<Ts*, Ts*, int>&... cpys)
{
// has the same size, or one is empty
FT_CHECK(src_idx.size() == dst_idx.size() || (src_idx.empty() ^ dst_idx.empty()));
IndexedCopyImpl(src_idx.empty() ? nullptr : src_idx.data(),
dst_idx.empty() ? nullptr : dst_idx.data(),
std::max(src_idx.size(), dst_idx.size()),
cpys...);
}
template<class... Ts>
void IndexedCopy(int count, const std::tuple<Ts*, Ts*, int>&... cpys)
{
IndexedCopyImpl(nullptr, nullptr, count, cpys...);
}
private:
const int max_batch_size_;
const int max_context_token_num_;
......@@ -186,9 +223,10 @@ private:
// used by dynamic decoder
int* token_ids_buf_{}; // all token IDs in [S, B], indexed using `step`
int* end_ids_buf_{};
bool* finished_buf_{};
uint32_t* seq_limit_len_{};
int* h_end_ids_buf_{};
int* d_end_ids_buf_{};
int** request_output_ids_ptrs_{};
int* request_output_ids_lens_{};
......@@ -205,13 +243,20 @@ private:
uintptr_t* h_k_block_ptrs_{};
uintptr_t* h_v_block_ptrs_{};
int* stop_words_buf_{}; // [batch_size, 2, kMaxStopWordsLen]
int* bad_words_buf_{};
int* h_runtime_top_k_{};
float* h_runtime_top_p_{};
float* h_temperature_{};
float* h_repetition_penalty_{};
uint64_t* h_random_seed_{};
int* h_runtime_top_k_{};
float* h_runtime_top_p_{};
float* h_temperature_{};
float* h_repetition_penalty_{};
int* h_stop_words_{}; // [batch_size, 2, kMaxStopWordsLen]
int* h_bad_words_{};
int* d_stop_words_{}; // [batch_size, 2, kMaxStopWordsLen]
int* d_bad_words_{};
unsigned long long* h_random_seed_{};
unsigned long long* d_random_seed_{};
curandState_t* h_curand_state_{};
curandState_t* d_curand_state_{};
std::array<BatchState, 3> states_{};
......@@ -232,7 +277,7 @@ private:
TensorMap inputs_;
TensorMap outputs_;
std::unordered_map<std::string, void*> sampling_params_;
std::vector<std::tuple<std::string, std::byte*, std::byte*>> sampling_params_;
cudaStream_t stream_{};
cublasMMWrapper* cublas_wrapper_{};
......@@ -244,8 +289,10 @@ private:
std::thread output_thread_;
std::mutex output_mutex_;
std::condition_variable output_cv_;
Requests output_reqs_;
std::vector<Signal> output_signals_;
bool output_stop_token_{false};
int* h_output_ids_{};
};
} // namespace turbomind
......@@ -45,37 +45,37 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
// no padding
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, false);
// padding is rebuilt for q/k/v_buf_2_
// [qH + 2kvH, B, S, D]
q_buf_2_ = (T*)allocator_->reMalloc(
q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, false);
k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;
if (use_fmha_) {
FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
if (flash_attention.get_workspace_size() > 0) {
qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), true);
qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), false);
}
}
else {
// kv heads are repeated for unfused attention
k_cache_buf_ = (T*)allocator_->reMalloc(
k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, false);
v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;
qk_buf_ =
(T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);
(T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, false);
// qkv_buf_2_ has padding
qkv_buf_2_ = (T*)allocator_->reMalloc(
qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, false);
}
// qkv_buf_3_ padding is removed
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, false);
is_allocate_buffer_ = true;
}
......
......@@ -45,7 +45,7 @@ void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
workspace_ = (float*)allocator_->reMalloc(
workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2));
workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2), false);
is_allocate_buffer_ = true;
}
......
......@@ -102,10 +102,6 @@ LlamaV2<T>::LlamaV2(size_t head_num,
size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
elem_bits = sizeof(int8_t) * 8;
if (use_context_fmha) {
TM_LOG_ERROR("use_context_fmha not support int8");
assert(0);
}
}
else {
elem_bits = sizeof(T) * 8;
......@@ -406,6 +402,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
bool* finished,
int* sequence_length,
bool* should_stop,
curandState_t* curand_state,
TensorMap* inputs,
TensorMap* outputs,
const float* logits,
......@@ -450,7 +447,8 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
{"output_ids", {MEMORY_GPU, TYPE_INT32, {token_ids_len, batch_size, 1U}, token_ids}},
{"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
{"sequence_length", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
{"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}}};
{"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}},
{"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, curand_state}}};
const std::vector<std::string> optional_outputs{"cum_log_probs", "output_log_probs"};
for (const auto& key : optional_outputs) {
......@@ -562,7 +560,7 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
if (ec) {
has_error = true;
}
TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
TM_LOG_INFO("[forward] Request complete for %ld, code %d", (long)ids[i], (int)ec);
}
}
......
......@@ -151,6 +151,7 @@ private:
bool* finished,
int* sequence_length,
bool* should_stop,
curandState_t* curand_state,
TensorMap* inputs,
TensorMap* outputs,
const float* logits,
......@@ -163,16 +164,6 @@ private:
size_t token_ids_len,
size_t batch_size);
curandState_t* GetTopKState(int index)
{
return dynamic_decode_layer_->topk_curandstate_buf() + index;
}
curandState_t* GetTopPState(int index)
{
return dynamic_decode_layer_->topp_curandstate_buf() + index;
}
private:
friend class LlamaBatch<T>;
......
......@@ -87,11 +87,8 @@ bool SequenceManager::Erase(uint64_t id)
}
}
sequences_.erase(it);
return true;
}
else {
throw std::out_of_range(std::to_string(id));
}
return false;
}
......
......@@ -58,13 +58,13 @@ public:
SequenceManager(const SequenceManager&) = delete;
SequenceManager(SequenceManager&&) noexcept = default;
const Sequence* Create(uint64_t id);
[[nodiscard]] const Sequence* Create(uint64_t id);
const Sequence* Get(uint64_t id);
[[nodiscard]] const Sequence* Get(uint64_t id);
bool Contains(uint64_t id);
[[nodiscard]] bool Contains(uint64_t id);
bool Erase(uint64_t id);
[[nodiscard]] bool Erase(uint64_t id);
void UpdateAndSetUnlock(const Sequence& seq);
......@@ -74,10 +74,10 @@ public:
int swap_out;
};
Outcome Materialize(Sequences sequences,
std::vector<int> context_lengths,
const std::vector<uint64_t>& priorities,
int step_length);
[[nodiscard]] Outcome Materialize(Sequences sequences,
std::vector<int> context_lengths,
const std::vector<uint64_t>& priorities,
int step_length);
void* OffsetKey(void* block_ptr)
{
......
......@@ -8,7 +8,11 @@
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <cstdint>
#include <cub/block/block_reduce.cuh>
#include <type_traits>
namespace turbomind {
......@@ -606,6 +610,173 @@ void invokeUpdateOutput(int** request_output_ids_ptrs,
token_generated);
}
template<int BLOCK_DIM>
__global__ void compactOutputIds(
int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, bool token_generated)
{
typedef cub::BlockReduce<int, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int batch_idx = blockIdx.x;
int end = (batch_idx + BLOCK_DIM - 1) / BLOCK_DIM * BLOCK_DIM; // align to BLOCK_DIM boundary
int count = 0;
for (int i = threadIdx.x; i < end; i += blockDim.x) {
int x = threadIdx.x < batch_idx ? sequence_lengths[threadIdx.x] : 0;
count += BlockReduce(temp_storage).Sum(x);
// https://nvlabs.github.io/cub/classcub_1_1_block_reduce.html
__syncthreads();
}
__shared__ int offset;
if (threadIdx.x == 0) {
offset = count;
}
__syncthreads();
auto dst = cu_output_ids + offset;
const int seq_len = sequence_lengths[batch_idx];
for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
dst[i] = output_ids[batch_idx * session_len + i];
}
}
void invokeCompactOutputIds(int* cu_output_ids,
const int* output_ids,
const int* sequence_lengths,
int max_session_len,
bool token_generated,
int batch_size,
cudaStream_t stream)
{
constexpr int BLOCK_DIM = 128;
compactOutputIds<BLOCK_DIM><<<batch_size, BLOCK_DIM, 0, stream>>>(
cu_output_ids, output_ids, sequence_lengths, max_session_len, token_generated);
}
template<int N, int C>
struct IndexedCopyParam {
Array<void*, N> src_ptr;
Array<void*, N> dst_ptr;
Array<int, N> stride;
Array<int, C> src_idx;
Array<int, C> dst_idx;
int max_stride;
};
template<class T, int N, int C>
__global__ void indexedCopy(IndexedCopyParam<N, C> param)
{
const int bi = blockIdx.x;
const int si = param.src_idx[bi];
const int di = param.dst_idx[bi];
for (int i = threadIdx.x; i < param.max_stride; i += blockDim.x) {
PRAGMA_UNROLL
for (int k = 0; k < N; ++k) {
if (i < param.stride[k]) {
*((T*)param.dst_ptr[k] + param.stride[k] * di + i) =
*((const T*)param.src_ptr[k] + param.stride[k] * si + i);
}
}
}
}
template<class T, int N>
void invokeIndexedCopyImpl(void** h_src_ptr,
void** h_dst_ptr,
const int* h_elem_sz,
const int* h_src_idx,
const int* h_dst_idx,
int count,
cudaStream_t st)
{
auto invoke = [&](auto max_count) {
constexpr int C = decltype(max_count)::value;
// maximum parameter size: sm<70: 4kB, sm>=70: 32kB
static_assert(sizeof(IndexedCopyParam<N, C>) <= 4096);
IndexedCopyParam<N, C> param{};
std::copy_n(h_src_ptr, N, param.src_ptr.data());
std::copy_n(h_dst_ptr, N, param.dst_ptr.data());
std::transform(h_elem_sz, h_elem_sz + N, param.stride.data(), [](int size) {
// Basic alignment check
FT_CHECK_WITH_INFO(size % sizeof(T) == 0, fmtstr("misalignment: %d %% %d", size, (int)sizeof(T)));
return size / sizeof(T);
});
param.max_stride = *std::max_element(param.stride.begin(), param.stride.end());
auto copy_idx = [](const int* src, int offset, int n, auto dst) {
return src ? (void)std::copy_n(src + offset, n, dst) : std::iota(dst, dst + n, offset);
};
for (int c = 0; c < count; c += C) {
int batch_size = std::min(count - c, C);
copy_idx(h_src_idx, c, batch_size, param.src_idx.data());
copy_idx(h_dst_idx, c, batch_size, param.dst_idx.data());
indexedCopy<T><<<batch_size, 128, 0, st>>>(param);
}
};
if (count <= 4) {
invoke(std::integral_constant<int, 4>{});
}
if (count <= 8) {
invoke(std::integral_constant<int, 8>{});
}
else if (count <= 16) {
invoke(std::integral_constant<int, 16>{});
}
else if (count <= 32) {
invoke(std::integral_constant<int, 32>{});
}
else if (count <= 64) {
invoke(std::integral_constant<int, 64>{});
}
else if (count <= 128) {
invoke(std::integral_constant<int, 128>{});
}
else {
invoke(std::integral_constant<int, 256>{});
}
}
void invokeIndexedCopy(void** h_src_ptr,
void** h_dst_ptr,
const int* h_elem_sz,
const int* h_src_idx,
const int* h_dst_idx,
int count,
int n_copys,
cudaStream_t st)
{
auto args = std::tuple{h_src_ptr, h_dst_ptr, h_elem_sz, h_src_idx, h_dst_idx, count, st};
switch (n_copys) {
case 1:
return std::apply(invokeIndexedCopyImpl<uint32_t, 1>, args);
case 2:
return std::apply(invokeIndexedCopyImpl<uint32_t, 2>, args);
case 3:
return std::apply(invokeIndexedCopyImpl<uint32_t, 3>, args);
case 4:
return std::apply(invokeIndexedCopyImpl<uint32_t, 4>, args);
default:
FT_CHECK(0);
}
}
__global__ void padLastTokenIds(int* token_ids, const int* context_length, int max_context_len, int batch_size)
{
for (int bi = threadIdx.x; bi < batch_size; bi += blockDim.x) {
token_ids[(max_context_len - 1) * batch_size + bi] = token_ids[(context_length[bi] - 1) * batch_size + bi];
}
}
void invokePadLastTokenIds(
int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream)
{
padLastTokenIds<<<1, 512, 0, stream>>>(token_ids, context_length, max_context_len, batch_size);
}
#define VERSION_SWITCH(VERSION, CONST_NAME, ...) \
[&] { \
if (VERSION == 2) { \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment