Unverified Commit 911c0a85 authored by Li Zhang, committed by GitHub

Optimize for throughput (#701)



* tmp

* update

* update

* optimize for throughput

* update

* fix eos

* clean up

* fix serving

* fix indexed copy

* minor

* minor

---------
Co-authored-by: lvhan028 <lvhan_028@163.com>
parent 65d735ba
......@@ -448,8 +448,8 @@ int main(int argc, char* argv[])
std::vector<int> hBuf(outCount);
ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
ft::cudaAutoCpy(hBuf.data(), d_output_ids, outCount);
ft::cudaAutoCpy(seq_lens.data(), d_seq_lens, batch_size);
std::cout << "sequence length: ";
for (int i = 0; i < batch_size; ++i) {
......
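Note: ft::cudaAutoCpy replaces the explicit ft::cudaD2Hcpy calls here. As a rough illustration of the idea (a hedged sketch, not the actual ft::cudaAutoCpy implementation), a direction-agnostic copy can rely on cudaMemcpyDefault, which lets the runtime infer the transfer direction from the pointer attributes under unified virtual addressing:

```cpp
#include <cuda_runtime.h>

// Illustrative sketch only; not the real ft::cudaAutoCpy.
// cudaMemcpyDefault infers D2H/H2D/D2D from the pointer attributes.
template<typename T>
void autoCpySketch(T* dst, const T* src, size_t count, cudaStream_t stream = nullptr)
{
    cudaMemcpyAsync(dst, src, sizeof(T) * count, cudaMemcpyDefault, stream);
}
```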
......@@ -350,7 +350,7 @@ class TurboMindInstance:
outputs = _tm_dict_to_torch_dict(tm_outputs)
output_ids = outputs['output_ids'][:, 0, :]
sequence_length = outputs['sequence_length'].long()[:, 0].cpu()
sequence_length = outputs['sequence_length'].long()[:, 0]
output_ids = [
output_id[s:l] for output_id, s, l in zip(
output_ids, seq_start, sequence_length)
......@@ -366,7 +366,6 @@ class TurboMindInstance:
outputs.append((output[:-1], len_))
else:
outputs.append((output, len_))
yield outputs
if finish:
......
......@@ -236,15 +236,88 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
template<typename T, int N>
struct Array {
T a[N];
__device__ __host__ constexpr T& operator[](int i) noexcept
using value_type = T;
using size_type = int;
using difference_type = int;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using iterator = pointer;
using const_iterator = const_pointer;
static_assert(N > 0);
T __a[N];
__device__ __host__ constexpr reference operator[](size_type i) noexcept
{
return __a[i];
}
__device__ __host__ constexpr const_reference operator[](size_type i) const noexcept
{
return __a[i];
}
__device__ __host__ constexpr reference front() noexcept
{
return *begin();
}
__device__ __host__ constexpr const_reference front() const noexcept
{
return *begin();
}
__device__ __host__ constexpr reference back() noexcept
{
return *(end() - 1);
}
__device__ __host__ constexpr const_reference back() const noexcept
{
return *(end() - 1);
}
__device__ __host__ constexpr pointer data() noexcept
{
return a[i];
return &__a[0];
}
__device__ __host__ constexpr const T& operator[](int i) const noexcept
__device__ __host__ constexpr const_pointer data() const noexcept
{
return &__a[0];
}
__device__ __host__ constexpr iterator begin() noexcept
{
return data();
}
__device__ __host__ constexpr const_iterator begin() const noexcept
{
return data();
}
__device__ __host__ constexpr iterator end() noexcept
{
return data() + N;
}
__device__ __host__ constexpr const_iterator end() const noexcept
{
return data() + N;
}
__device__ __host__ constexpr std::integral_constant<int, N> size() const noexcept
{
return {};
}
__device__ __host__ constexpr std::false_type empty() const noexcept
{
return a[i];
return {};
}
};
......
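The rewritten Array<T, N> mirrors the std::array interface (data, begin/end, front/back, a compile-time size). A minimal usage sketch, assuming the struct above is in scope of a CUDA translation unit; the function names are illustrative only:

```cpp
// Illustrative only; assumes Array<T, N> from the hunk above is visible.
__host__ __device__ void scaleBy2(Array<float, 4>& a)
{
    // begin()/end() enable range-based for on both host and device
    for (auto& x : a) {
        x *= 2.0f;
    }
}

__host__ __device__ constexpr int elementCount(const Array<float, 4>& a)
{
    // size() returns std::integral_constant<int, N>, so N stays a compile-time constant
    return decltype(a.size())::value;
}
```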
......@@ -188,6 +188,7 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
*
* output_tensors:
* \param output_ids [max_seq_len, batch_size]
* \param curand_state [local_batch_size]
* \param finished [batch_size * beam_width], optional
* \param should_stop [1] on cpu
* \param cum_log_probs [batch_size * beam_width], necessary in beam search
......@@ -276,7 +277,8 @@ void DynamicDecodeLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_
{"input_lengths", input_lengths.slice({local_batch_size, beam_width}, local_batch_offset)});
}
TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}});
TensorMap decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}, //
{"curand_state", output_tensors->at("curand_state")}});
if (output_tensors->isExist("sequence_length")) {
Tensor sequence_length = output_tensors->at("sequence_length");
decode_output_tensors.insert(
......
......@@ -53,15 +53,6 @@ protected:
int* h_pinned_finished_sum_ = nullptr;
public:
curandState_t* topk_curandstate_buf()
{
return static_cast<BaseSamplingLayer<T>*>(topk_decode_)->curandstate_buf();
}
curandState_t* topp_curandstate_buf()
{
return static_cast<BaseSamplingLayer<T>*>(topp_decode_)->curandstate_buf();
}
DynamicDecodeLayer(size_t vocab_size,
size_t vocab_size_padded,
int end_id,
......
......@@ -30,10 +30,6 @@ template<typename T>
void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
curandstate_buf_ = reinterpret_cast<curandState_t*>(
allocator_->reMalloc(curandstate_buf_, sizeof(curandState_t) * batch_size, false));
random_seeds_buf_ = reinterpret_cast<unsigned long long*>(
allocator_->reMalloc(random_seeds_buf_, sizeof(unsigned long long) * batch_size, false));
temperature_buf_ =
reinterpret_cast<float*>(allocator_->reMalloc(temperature_buf_, sizeof(float) * batch_size, false));
repetition_penalty_buf_ =
......@@ -58,8 +54,6 @@ void BaseSamplingLayer<T>::freeBuffer()
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (is_allocate_buffer_) {
allocator_->free((void**)(&curandstate_buf_));
allocator_->free((void**)(&random_seeds_buf_));
allocator_->free((void**)(&temperature_buf_));
allocator_->free((void**)(&repetition_penalty_buf_));
allocator_->free((void**)(&min_lengths_buf_));
......@@ -128,32 +122,6 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
Tensor runtime_top_p = runtime_args->isExist("runtime_top_p") ? runtime_args->at("runtime_top_p") : Tensor();
allocateBuffer(batch_size, runtime_top_k, runtime_top_p);
// If the runtime argument carries a single random seed, use it to initialize the random table of every
// sentence. If it carries [batch_size] random seeds, initialize each sentence's random table with its own seed.
// If no random seed is given, initialize the random table of all sentences with seed 0.
if (runtime_args->isExist("random_seed")) {
Tensor random_seeds = runtime_args->at("random_seed");
FT_CHECK_WITH_INFO(random_seeds.shape.size() == 1
&& (random_seeds.size() == 1 || random_seeds.size() == batch_size),
fmtstr("random_seeds must be of shape [1] or [batch_size(%ld)], got random_seeds.shape=%s",
batch_size,
vec2str(random_seeds.shape).c_str()));
if (random_seeds.size() == 1) {
invokeCurandInitialize(curandstate_buf_, batch_size, random_seeds.getVal<unsigned long long>(), stream_);
sync_check_cuda_error();
}
else {
unsigned long long* random_seed_ptr = random_seeds.getPtr<unsigned long long>();
cudaAutoCpy(random_seeds_buf_, random_seed_ptr, batch_size, stream_);
invokeCurandBatchInitialize(curandstate_buf_, batch_size, random_seeds_buf_, stream_);
sync_check_cuda_error();
}
}
else {
// Initialize curand states using the default seed 0.
invokeCurandInitialize(curandstate_buf_, batch_size, 0, stream_);
}
// Setup penalties.
const float default_temperature = 1.0f;
Tensor temperature = runtime_args->isExist("temperature") ?
......
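The block removed above handled per-layer seeding via invokeCurandInitialize / invokeCurandBatchInitialize; the curand states now live with the batch and are passed in through the tensor maps. For reference, a batched initialization kernel conceptually looks like the following (an illustrative sketch, not the project's implementation):

```cpp
#include <curand_kernel.h>

// Illustrative sketch: one curand state per sequence, each seeded independently.
__global__ void curandBatchInitSketch(curandState_t* states, const unsigned long long* seeds, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // arguments: seed, subsequence, offset, state
        curand_init(seeds[i], 0, 0, &states[i]);
    }
}
```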
......@@ -35,8 +35,6 @@ protected:
size_t sampling_workspace_size_;
void* sampling_workspace_ = nullptr;
curandState_t* curandstate_buf_ = nullptr;
unsigned long long* random_seeds_buf_ = nullptr;
float* temperature_buf_ = nullptr;
float* repetition_penalty_buf_ = nullptr;
......@@ -59,11 +57,6 @@ protected:
virtual void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p);
public:
curandState_t* curandstate_buf()
{
return curandstate_buf_;
}
BaseSamplingLayer(size_t max_batch_size,
size_t vocab_size,
size_t vocab_size_padded,
......
......@@ -16,6 +16,7 @@
*/
#include <float.h>
#include <sstream>
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
......@@ -199,6 +200,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
// output_tensors:
// output_ids [max_seq_len, batch_size]
// curand_state [local_batch_size]
// finished [local_batch_size], optional
// sequence_length [local_batch_size], optional
// cum_log_probs [batch_size], must be float*, optional
......@@ -255,7 +257,7 @@ void TopKSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
output_tensors->at("finished", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<bool>(),
cum_log_probs,
output_log_probs,
curandstate_buf_ + ite * local_batch_size,
output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
(int)runtime_max_top_k_, // unused because runtime_top_k_buf_ is never nullptr; kept for legacy.
(int*)(runtime_top_k_buf_ + ite * local_batch_size),
1.0f, // unused because runtime_top_p_buf_ is never nullptr; kept for legacy.
......
......@@ -40,8 +40,6 @@ private:
using BaseSamplingLayer<T>::sampling_workspace_size_;
using BaseSamplingLayer<T>::sampling_workspace_;
using BaseSamplingLayer<T>::curandstate_buf_;
using BaseSamplingLayer<T>::random_seeds_buf_;
using BaseSamplingLayer<T>::skip_decode_buf_;
using BaseSamplingLayer<T>::skip_decode_;
using BaseSamplingLayer<T>::skip_any_;
......
......@@ -132,7 +132,7 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
topp_id_vals_buf_,
topp_offset_buf_,
begin_topp_offset_buf_,
curandstate_buf_,
nullptr, // not used when workspace is null
batch_size,
vocab_size_padded_,
nullptr,
......@@ -267,6 +267,7 @@ void TopPSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
* output_tensors:
* \param output_ids [max_seq_len, batch_size]
* \param curand_state [local_batch_size]
* \param finished [local_batch_size], optional
* \param sequence_length [local_batch_size], optional
* \param cum_log_probs [batch_size], must be float*, optional
......@@ -319,7 +320,7 @@ void TopPSamplingLayer<T>::runSampling(TensorMap* output_tensors, TensorMap* inp
topp_id_vals_buf_,
topp_offset_buf_,
begin_topp_offset_buf_,
curandstate_buf_ + ite * local_batch_size,
output_tensors->at("curand_state").getPtr<curandState_t>() + ite * local_batch_size,
local_batch_size,
vocab_size_padded_,
input_tensors->at("end_id").getPtr<int>(),
......
......@@ -48,8 +48,6 @@ private:
using BaseSamplingLayer<T>::sampling_workspace_size_;
using BaseSamplingLayer<T>::sampling_workspace_;
using BaseSamplingLayer<T>::curandstate_buf_;
using BaseSamplingLayer<T>::random_seeds_buf_;
using BaseSamplingLayer<T>::skip_decode_buf_;
using BaseSamplingLayer<T>::skip_decode_;
using BaseSamplingLayer<T>::skip_any_;
......
This diff is collapsed.
......@@ -3,14 +3,18 @@
#pragma once
// #include "src/turbomind/models/llama/LlamaCacheManager.h"
#include "src/turbomind/layers/sampling_layers/BaseSamplingLayer.h"
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include <condition_variable>
#include <mutex>
#include <type_traits>
namespace turbomind {
......@@ -18,8 +22,7 @@ struct BatchState {
int* h_context_length;
bool* h_finished;
void* top_k_curand_state;
void* top_p_curand_state;
curandState_t* curand_state;
int* output_ids; // output ids in [B, S]
float* h_rope_theta;
......@@ -67,15 +70,14 @@ public:
};
void InitializeSampling();
GenerationState InitializeGeneration();
[[nodiscard]] bool Generate(GenerationState& g);
[[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
void CompleteRequest(int index, bool is_stop_request, bool is_force_end);
[[nodiscard]] auto Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>;
void SetOutputTensors(const GenerationState& g);
[[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
void
OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
......@@ -88,7 +90,7 @@ public:
~LlamaBatch()
{
TM_LOG_ERROR("~LlamaBatch()");
TM_LOG_INFO("~LlamaBatch()");
model_->shared_state_->request_queue.close();
internal_thread_.join();
......@@ -112,15 +114,9 @@ private:
void OutputThreadEntry();
void UpdateSequenceStates(BatchState& state, int index);
void CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst);
void SaveRandomState(BatchState& state, int idx);
void LoadRandomState(BatchState& state, int idx);
void CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc);
void BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals);
void SendSignals(std::vector<Signal> signals);
// analogous to `std::copy_n`
template<typename U>
......@@ -137,6 +133,47 @@ private:
return data += count;
}
template<class... Ts>
void IndexedCopyImpl(const int* src_idx, const int* dst_idx, int count, const std::tuple<Ts*, Ts*, int>&... cpys)
{
if (!count) {
return;
}
constexpr int N = sizeof...(Ts);
static_assert((!std::is_same_v<Ts, void> && ...));
std::array<void*, N> src_ptr{std::get<0>(cpys)...};
std::array<void*, N> dst_ptr{std::get<1>(cpys)...};
std::array<int, N> elem_sz{int(sizeof(Ts) * std::get<2>(cpys))...};
invokeIndexedCopy(src_ptr.data(), //
dst_ptr.data(),
elem_sz.data(),
src_idx,
dst_idx,
count,
N,
stream_);
sync_check_cuda_error();
}
template<class... Ts>
void IndexedCopy(const std::vector<int>& src_idx,
const std::vector<int>& dst_idx,
const std::tuple<Ts*, Ts*, int>&... cpys)
{
// both index vectors must have the same size, or exactly one of them must be empty
FT_CHECK(src_idx.size() == dst_idx.size() || (src_idx.empty() ^ dst_idx.empty()));
IndexedCopyImpl(src_idx.empty() ? nullptr : src_idx.data(),
dst_idx.empty() ? nullptr : dst_idx.data(),
std::max(src_idx.size(), dst_idx.size()),
cpys...);
}
template<class... Ts>
void IndexedCopy(int count, const std::tuple<Ts*, Ts*, int>&... cpys)
{
IndexedCopyImpl(nullptr, nullptr, count, cpys...);
}
private:
const int max_batch_size_;
const int max_context_token_num_;
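A hedged usage sketch for the IndexedCopy helpers above, as it might appear inside a LlamaBatch<T> member function. The BatchState fields output_ids and curand_state come from the hunk earlier in this diff; the index values and session_len are made up for illustration:

```cpp
// Illustrative only: `src` and `dst` are BatchState& (see the BatchState hunk above).
// Gathers sequences {3, 5, 7} of `src` into slots {0, 1, 2} of `dst` with a single
// invokeIndexedCopy launch.
const int session_len = 2048;  // hypothetical row length of output_ids
std::vector<int> src_idx{3, 5, 7};
std::vector<int> dst_idx{0, 1, 2};
IndexedCopy(src_idx,
            dst_idx,
            std::tuple{src.output_ids, dst.output_ids, session_len},  // int rows of `session_len` tokens
            std::tuple{src.curand_state, dst.curand_state, 1});       // one curandState_t per sequence
```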
......@@ -186,9 +223,10 @@ private:
// used by dynamic decoder
int* token_ids_buf_{}; // all token IDs in [S, B], indexed using `step`
int* end_ids_buf_{};
bool* finished_buf_{};
uint32_t* seq_limit_len_{};
int* h_end_ids_buf_{};
int* d_end_ids_buf_{};
int** request_output_ids_ptrs_{};
int* request_output_ids_lens_{};
......@@ -205,13 +243,20 @@ private:
uintptr_t* h_k_block_ptrs_{};
uintptr_t* h_v_block_ptrs_{};
int* stop_words_buf_{}; // [batch_size, 2, kMaxStopWordsLen]
int* bad_words_buf_{};
int* h_runtime_top_k_{};
float* h_runtime_top_p_{};
float* h_temperature_{};
float* h_repetition_penalty_{};
uint64_t* h_random_seed_{};
int* h_stop_words_{}; // [batch_size, 2, kMaxStopWordsLen]
int* h_bad_words_{};
int* d_stop_words_{}; // [batch_size, 2, kMaxStopWordsLen]
int* d_bad_words_{};
unsigned long long* h_random_seed_{};
unsigned long long* d_random_seed_{};
curandState_t* h_curand_state_{};
curandState_t* d_curand_state_{};
std::array<BatchState, 3> states_{};
......@@ -232,7 +277,7 @@ private:
TensorMap inputs_;
TensorMap outputs_;
std::unordered_map<std::string, void*> sampling_params_;
std::vector<std::tuple<std::string, std::byte*, std::byte*>> sampling_params_;
cudaStream_t stream_{};
cublasMMWrapper* cublas_wrapper_{};
......@@ -244,8 +289,10 @@ private:
std::thread output_thread_;
std::mutex output_mutex_;
std::condition_variable output_cv_;
Requests output_reqs_;
std::vector<Signal> output_signals_;
bool output_stop_token_{false};
int* h_output_ids_{};
};
} // namespace turbomind
......@@ -45,37 +45,37 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
// no padding
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, false);
// padding is rebuilt for q/k/v_buf_2_
// [qH + 2kvH, B, S, D]
q_buf_2_ = (T*)allocator_->reMalloc(
q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, false);
k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;
if (use_fmha_) {
FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
if (flash_attention.get_workspace_size() > 0) {
qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), true);
qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), false);
}
}
else {
// kv heads are repeated for unfused attention
k_cache_buf_ = (T*)allocator_->reMalloc(
k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, false);
v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;
qk_buf_ =
(T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);
(T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, false);
// qkv_buf_2_ has padding
qkv_buf_2_ = (T*)allocator_->reMalloc(
qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, false);
}
// qkv_buf_3_ padding is removed
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, false);
is_allocate_buffer_ = true;
}
......
......@@ -45,7 +45,7 @@ void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
workspace_ = (float*)allocator_->reMalloc(
workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2));
workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2), false);
is_allocate_buffer_ = true;
}
......
......@@ -102,10 +102,6 @@ LlamaV2<T>::LlamaV2(size_t head_num,
size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
elem_bits = sizeof(int8_t) * 8;
if (use_context_fmha) {
TM_LOG_ERROR("use_context_fmha not support int8");
assert(0);
}
}
else {
elem_bits = sizeof(T) * 8;
......@@ -406,6 +402,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
bool* finished,
int* sequence_length,
bool* should_stop,
curandState_t* curand_state,
TensorMap* inputs,
TensorMap* outputs,
const float* logits,
......@@ -450,7 +447,8 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
{"output_ids", {MEMORY_GPU, TYPE_INT32, {token_ids_len, batch_size, 1U}, token_ids}},
{"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
{"sequence_length", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
{"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}}};
{"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}},
{"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, curand_state}}};
const std::vector<std::string> optional_outputs{"cum_log_probs", "output_log_probs"};
for (const auto& key : optional_outputs) {
......@@ -562,7 +560,7 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
if (ec) {
has_error = true;
}
TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
TM_LOG_INFO("[forward] Request complete for %ld, code %d", (long)ids[i], (int)ec);
}
}
......
......@@ -151,6 +151,7 @@ private:
bool* finished,
int* sequence_length,
bool* should_stop,
curandState_t* curand_state,
TensorMap* inputs,
TensorMap* outputs,
const float* logits,
......@@ -163,16 +164,6 @@ private:
size_t token_ids_len,
size_t batch_size);
curandState_t* GetTopKState(int index)
{
return dynamic_decode_layer_->topk_curandstate_buf() + index;
}
curandState_t* GetTopPState(int index)
{
return dynamic_decode_layer_->topp_curandstate_buf() + index;
}
private:
friend class LlamaBatch<T>;
......
......@@ -87,11 +87,8 @@ bool SequenceManager::Erase(uint64_t id)
}
}
sequences_.erase(it);
return true;
}
else {
throw std::out_of_range(std::to_string(id));
}
return false;
}
......
......@@ -58,13 +58,13 @@ public:
SequenceManager(const SequenceManager&) = delete;
SequenceManager(SequenceManager&&) noexcept = default;
const Sequence* Create(uint64_t id);
[[nodiscard]] const Sequence* Create(uint64_t id);
const Sequence* Get(uint64_t id);
[[nodiscard]] const Sequence* Get(uint64_t id);
bool Contains(uint64_t id);
[[nodiscard]] bool Contains(uint64_t id);
bool Erase(uint64_t id);
[[nodiscard]] bool Erase(uint64_t id);
void UpdateAndSetUnlock(const Sequence& seq);
......@@ -74,7 +74,7 @@ public:
int swap_out;
};
Outcome Materialize(Sequences sequences,
[[nodiscard]] Outcome Materialize(Sequences sequences,
std::vector<int> context_lengths,
const std::vector<uint64_t>& priorities,
int step_length);
......
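With Erase() now returning false instead of throwing, and the lookup/creation methods marked [[nodiscard]], callers are expected to check results explicitly. A hedged sketch of the intended call pattern (seq_manager and session_id are illustrative names):

```cpp
// Illustrative caller pattern; names are hypothetical.
if (!seq_manager.Erase(session_id)) {
    TM_LOG_INFO("[Erase] session %lu not found", (unsigned long)session_id);
}
```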
......@@ -8,7 +8,11 @@
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <cstdint>
#include <cub/block/block_reduce.cuh>
#include <type_traits>
namespace turbomind {
......@@ -606,6 +610,173 @@ void invokeUpdateOutput(int** request_output_ids_ptrs,
token_generated);
}
template<int BLOCK_DIM>
__global__ void compactOutputIds(
int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, bool token_generated)
{
typedef cub::BlockReduce<int, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int batch_idx = blockIdx.x;
int end = (batch_idx + BLOCK_DIM - 1) / BLOCK_DIM * BLOCK_DIM; // align to BLOCK_DIM boundary
int count = 0;
for (int i = threadIdx.x; i < end; i += blockDim.x) {
int x = i < batch_idx ? sequence_lengths[i] : 0;
count += BlockReduce(temp_storage).Sum(x);
// https://nvlabs.github.io/cub/classcub_1_1_block_reduce.html
__syncthreads();
}
__shared__ int offset;
if (threadIdx.x == 0) {
offset = count;
}
__syncthreads();
auto dst = cu_output_ids + offset;
const int seq_len = sequence_lengths[batch_idx];
for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
dst[i] = output_ids[batch_idx * session_len + i];
}
}
void invokeCompactOutputIds(int* cu_output_ids,
const int* output_ids,
const int* sequence_lengths,
int max_session_len,
bool token_generated,
int batch_size,
cudaStream_t stream)
{
constexpr int BLOCK_DIM = 128;
compactOutputIds<BLOCK_DIM><<<batch_size, BLOCK_DIM, 0, stream>>>(
cu_output_ids, output_ids, sequence_lengths, max_session_len, token_generated);
}
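For clarity, here is a host-side reference of what compactOutputIds computes (a hedged sketch, not part of the codebase): each sequence's destination offset is the exclusive prefix sum of the preceding sequence lengths, and its tokens are then gathered into one contiguous buffer.

```cpp
// Host-side reference (illustrative only) of the compaction done by compactOutputIds.
void compactOutputIdsReference(
    int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, int batch_size)
{
    int offset = 0;  // exclusive prefix sum of sequence_lengths[0..b)
    for (int b = 0; b < batch_size; ++b) {
        for (int i = 0; i < sequence_lengths[b]; ++i) {
            cu_output_ids[offset + i] = output_ids[b * session_len + i];
        }
        offset += sequence_lengths[b];
    }
}
```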
template<int N, int C>
struct IndexedCopyParam {
Array<void*, N> src_ptr;
Array<void*, N> dst_ptr;
Array<int, N> stride;
Array<int, C> src_idx;
Array<int, C> dst_idx;
int max_stride;
};
template<class T, int N, int C>
__global__ void indexedCopy(IndexedCopyParam<N, C> param)
{
const int bi = blockIdx.x;
const int si = param.src_idx[bi];
const int di = param.dst_idx[bi];
for (int i = threadIdx.x; i < param.max_stride; i += blockDim.x) {
PRAGMA_UNROLL
for (int k = 0; k < N; ++k) {
if (i < param.stride[k]) {
*((T*)param.dst_ptr[k] + param.stride[k] * di + i) =
*((const T*)param.src_ptr[k] + param.stride[k] * si + i);
}
}
}
}
template<class T, int N>
void invokeIndexedCopyImpl(void** h_src_ptr,
void** h_dst_ptr,
const int* h_elem_sz,
const int* h_src_idx,
const int* h_dst_idx,
int count,
cudaStream_t st)
{
auto invoke = [&](auto max_count) {
constexpr int C = decltype(max_count)::value;
// maximum parameter size: sm<70: 4kB, sm>=70: 32kB
static_assert(sizeof(IndexedCopyParam<N, C>) <= 4096);
IndexedCopyParam<N, C> param{};
std::copy_n(h_src_ptr, N, param.src_ptr.data());
std::copy_n(h_dst_ptr, N, param.dst_ptr.data());
std::transform(h_elem_sz, h_elem_sz + N, param.stride.data(), [](int size) {
// Basic alignment check
FT_CHECK_WITH_INFO(size % sizeof(T) == 0, fmtstr("misalignment: %d %% %d", size, (int)sizeof(T)));
return size / sizeof(T);
});
param.max_stride = *std::max_element(param.stride.begin(), param.stride.end());
auto copy_idx = [](const int* src, int offset, int n, auto dst) {
return src ? (void)std::copy_n(src + offset, n, dst) : std::iota(dst, dst + n, offset);
};
for (int c = 0; c < count; c += C) {
int batch_size = std::min(count - c, C);
copy_idx(h_src_idx, c, batch_size, param.src_idx.data());
copy_idx(h_dst_idx, c, batch_size, param.dst_idx.data());
indexedCopy<T><<<batch_size, 128, 0, st>>>(param);
}
};
if (count <= 4) {
invoke(std::integral_constant<int, 4>{});
}
else if (count <= 8) {
invoke(std::integral_constant<int, 8>{});
}
else if (count <= 16) {
invoke(std::integral_constant<int, 16>{});
}
else if (count <= 32) {
invoke(std::integral_constant<int, 32>{});
}
else if (count <= 64) {
invoke(std::integral_constant<int, 64>{});
}
else if (count <= 128) {
invoke(std::integral_constant<int, 128>{});
}
else {
invoke(std::integral_constant<int, 256>{});
}
}
void invokeIndexedCopy(void** h_src_ptr,
void** h_dst_ptr,
const int* h_elem_sz,
const int* h_src_idx,
const int* h_dst_idx,
int count,
int n_copys,
cudaStream_t st)
{
auto args = std::tuple{h_src_ptr, h_dst_ptr, h_elem_sz, h_src_idx, h_dst_idx, count, st};
switch (n_copys) {
case 1:
return std::apply(invokeIndexedCopyImpl<uint32_t, 1>, args);
case 2:
return std::apply(invokeIndexedCopyImpl<uint32_t, 2>, args);
case 3:
return std::apply(invokeIndexedCopyImpl<uint32_t, 3>, args);
case 4:
return std::apply(invokeIndexedCopyImpl<uint32_t, 4>, args);
default:
FT_CHECK(0);
}
}
__global__ void padLastTokenIds(int* token_ids, const int* context_length, int max_context_len, int batch_size)
{
for (int bi = threadIdx.x; bi < batch_size; bi += blockDim.x) {
token_ids[(max_context_len - 1) * batch_size + bi] = token_ids[(context_length[bi] - 1) * batch_size + bi];
}
}
void invokePadLastTokenIds(
int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream)
{
padLastTokenIds<<<1, 512, 0, stream>>>(token_ids, context_length, max_context_len, batch_size);
}
#define VERSION_SWITCH(VERSION, CONST_NAME, ...) \
[&] { \
if (VERSION == 2) { \
......