Unverified commit ab1767cf authored by Li Zhang, committed by GitHub

TurboMind 2 (#590)

* refresh decoder attention kernel

* block-level kv cache

* `BlockManager` & `SequenceManager`

* update

* update

* update

* update

* rename

* GQA support

* fix context length

* GQA dispatch

* kv8

* tune

* async stream cb

* nvtx

* config parsing

* debug

* optimize output cost

* split-k decoding

* minor

* truncate `session_len` by available blocks

* minor

* license

* fix

* dispatch `cp.async`

* fix linking

* fix

* fix deadlock

* guard input length

* correct start offset

* fix prefill chunking

* fix `cache_block_seq_len` param passing

* fix `block_size` fmtstr

* fix output tokens

* fix batch resizing

* fix masking of finished sequences

* add debug util

* free unused block early

* add ntk scaling and logn scaling

* cmake flags

* fix typo

* w4a16 for sm75

* fix msvc build

* fix msvc build

* fix block verification

* fix msvc build

* use `std::shuffle`

* fix lint

* fix lint

* fix lint

* clear incoming buffer

* clear finished requests

* fix batch initialization

* fix typo

* fix typo

* fix comparison
parent 06125966
......@@ -9,6 +9,23 @@
namespace turbomind {
__inline__ __device__ void
mma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)
{
#if TURBOMIND_ARCH_SM75
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
float const* C = reinterpret_cast<float const*>(&c);
float* D = reinterpret_cast<float*>(&d);
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
"{%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
assert(TURBOMIND_ARCH_SM75);
#endif
}
__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
{
......@@ -22,7 +39,10 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
assert(TURBOMIND_ARCH_SM80);
const Array<half, 4>* _a = (const Array<half, 4>*)&a;
const Array<half, 2>* _b = (const Array<half, 2>*)&b;
mma_m16n8k8_row_col(d, _a[0], _b[0], c);
mma_m16n8k8_row_col(d, _a[1], _b[1], d);
#endif
}
......
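Editor's note (illustration, not part of this diff): the SM75 fallback above emulates the `m16n8k16` MMA, which requires SM80, by splitting the K dimension into two `m16n8k8` MMAs — the first accumulates into `c`, the second accumulates into the running `d`. The sketch below checks the same K-split identity on plain float matrices on the host; it ignores the tensor-core register/fragment layout and only demonstrates the math.

// Host-side check of the K-split identity used by the SM75 fallback:
//   D = A(16x16) * B(16x8) + C  ==  A_hi * B_hi + (A_lo * B_lo + C)
#include <array>
#include <cassert>
#include <cmath>
#include <cstdlib>

constexpr int M = 16, N = 8, K = 16;

// d = a[:, k0:k0+kk] * b[k0:k0+kk, :] + c
void mma_ref(float* d, const float* a, const float* b, const float* c, int k0, int kk)
{
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float acc = c[m * N + n];
            for (int k = k0; k < k0 + kk; ++k)
                acc += a[m * K + k] * b[k * N + n];
            d[m * N + n] = acc;
        }
}

int main()
{
    std::array<float, M * K> a;
    std::array<float, K * N> b;
    std::array<float, M * N> c, d_full, d_split;
    for (auto& x : a) x = rand() / float(RAND_MAX);
    for (auto& x : b) x = rand() / float(RAND_MAX);
    for (auto& x : c) x = rand() / float(RAND_MAX);

    mma_ref(d_full.data(), a.data(), b.data(), c.data(), 0, K);                 // one k16 MMA
    mma_ref(d_split.data(), a.data(), b.data(), c.data(), 0, K / 2);            // first k8 half, accumulates c
    mma_ref(d_split.data(), a.data(), b.data(), d_split.data(), K / 2, K / 2);  // second k8 half, accumulates d

    for (int i = 0; i < M * N; ++i)
        assert(std::fabs(d_full[i] - d_split[i]) < 1e-4f);
    return 0;
}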
......@@ -15,7 +15,7 @@
* limitations under the License.
*/
#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
......@@ -854,19 +854,20 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
T* v_buf,
T* QKV,
const T* __restrict qkv_bias,
const int* padding_offset,
const int* history_length,
const int* input_length,
int batch_size,
int seq_len,
int head_num,
int kv_head_num,
int size_per_head,
int rotary_embedding_dim,
float rotary_embedding_base,
int max_position_embeddings,
bool use_dynamic_ntk,
bool use_logn_attn)
const int* padding_offset,
const int* context_length,
const int* input_length,
const float* rope_theta,
int batch_size,
int seq_len,
int head_num,
int kv_head_num,
int size_per_head,
int rotary_embedding_dim,
float rotary_embedding_base,
int max_position_embeddings,
bool use_dynamic_ntk,
bool use_logn_attn)
{
// This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
// QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head].
......@@ -907,12 +908,18 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
Vec_t q, k, v;
Vec_t q_bias, k_bias, v_bias;
using Vec = Array<T, vec_size>;
static_assert(sizeof(Vec_t) == sizeof(Vec));
using namespace ops;
// load Q and apply bias
if (!is_masked) {
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
if (qkv_bias) {
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
q = mmha::add(q, q_bias);
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
(Vec&)q = (Vec&)q + (Vec&)q_bias;
}
}
......@@ -921,35 +928,32 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (qkv_bias) {
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
(Vec&)k = (Vec&)k + (Vec&)k_bias;
(Vec&)v = (Vec&)v + (Vec&)v_bias;
}
}
const int history_len = history_length[batch_idx];
const int context_len = history_len + input_length[batch_idx];
const int context_len = context_length[batch_idx];
const int history_len = context_len - input_length[batch_idx];
const int timestep = history_len + seq_idx;
if (use_dynamic_ntk) {
rotary_embedding_base = mmha::rotary_embedding_get_base(
context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base);
if (rope_theta) {
rotary_embedding_base = rope_theta[batch_idx];
}
// TODO: unused computation on k if GQA is used
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep);
RotaryEmbedding<vec_size> rotary_emb(rotary_embedding_base, rotary_embedding_dim, timestep, {tidx * vec_size, 0});
rotary_emb.apply((Array<T, vec_size>&)q);
if (head_idx < kv_head_num) {
rotary_emb.apply((Array<T, vec_size>&)k);
}
if (use_logn_attn) {
// +1 to convert to context length at the timestep
float logn_scaling = mmha::logn_attn_get_scaling(timestep + 1, max_position_embeddings);
if constexpr (std::is_same_v<T, float>) {
q = mmha::mul<Vec_t, float, Vec_t>(logn_scaling, q);
}
else if constexpr (std::is_same_v<T, half>) {
half tmp = __float2half(logn_scaling);
q = mmha::mul<Vec_t, uint16_t, Vec_t>((uint16_t&)tmp, q);
}
LogNScaling logn_scaling(timestep + 1, max_position_embeddings);
logn_scaling.apply((Array<T, vec_size>&)q);
}
if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present
......@@ -982,8 +986,9 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
QKV, \
qkv_bias, \
padding_offset, \
history_length, \
context_length, \
input_length, \
rope_theta, \
batch_size, \
seq_len, \
head_num, \
......@@ -1002,8 +1007,9 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int* history_length,
const int* context_length,
const int* input_length,
const float* rope_theta,
const int batch_size,
const int seq_len,
const int token_num,
......@@ -1034,6 +1040,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int* padding_offset, \
const int* history_length, \
const int* input_length, \
const float* rope_theta, \
const int batch_size, \
const int seq_len, \
const int token_num, \
......
......@@ -70,8 +70,9 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int* history_length,
const int* context_length,
const int* input_length,
const float* rope_theta,
const int batch_size,
const int seq_len,
const int token_num,
......
......@@ -2,6 +2,7 @@
#pragma once
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#ifndef _MSC_VER
#include <pthread.h>
......
// Copyright (c) OpenMMLab. All rights reserved.
#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <iterator>
#include <stdexcept>
namespace turbomind {
BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator):
block_size_(block_size), allocator_(allocator)
{
if (block_count < 1.) {
max_block_count_ = GetBlockCount(block_size, block_count);
}
else {
max_block_count_ = block_count;
}
if (chunk_size == 0) {
chunk_size_ = static_cast<int>(std::sqrt(max_block_count_));
}
else if (chunk_size < 0) {
chunk_size_ = max_block_count_;
}
else {
chunk_size_ = chunk_size;
}
TM_LOG_INFO("[BlockManager] block_size = %lu MB", (unsigned long)block_size_ >> 20);
TM_LOG_INFO("[BlockManager] max_block_count = %d", max_block_count_);
TM_LOG_INFO("[BlockManager] chunk_size = %d", chunk_size_);
blocks_.reserve(max_block_count_);
active_ids_.reserve(max_block_count_);
cached_ids_.reserve(max_block_count_);
free_ids_.reserve(max_block_count_);
// pre-allocate first chunk
Malloc();
dbg(free_ids_);
}
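Editor's note (illustration, not part of this diff): `chunk_size` controls how lazily the block pool grows — 0 grows in roughly sqrt(max_block_count)-sized chunks, a negative value allocates the whole pool up front, and a positive value is used as-is. A hypothetical helper mirroring that policy:

#include <cmath>

// Hypothetical helper; mirrors the chunk_size policy in BlockManager's constructor.
int resolve_chunk_size(int chunk_size, int max_block_count)
{
    if (chunk_size == 0)  // grow lazily in ~sqrt(N)-sized chunks
        return static_cast<int>(std::sqrt(max_block_count));
    if (chunk_size < 0)   // allocate the whole pool in one chunk
        return max_block_count;
    return chunk_size;    // explicit chunk size
}
// e.g. resolve_chunk_size(0, 1024) == 32, resolve_chunk_size(-1, 1024) == 1024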
BlockManager::~BlockManager()
{
for (auto& chunk : chunks_) {
allocator_->free(&chunk);
}
}
bool BlockManager::Malloc()
{
auto chunk_size = std::min<int>(chunk_size_, max_block_count_ - blocks_.size());
if (!chunk_size) {
return false;
}
auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
if (!ptr) {
return false;
}
chunks_.push_back(ptr);
for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
auto& block = blocks_.emplace_back();
block.use_count = 0;
block.ref_count = 0;
block.id = (int)blocks_.size() - 1;
block.timestamp = 0;
block.data = ptr;
free_ids_.push_back(block.id);
}
return true;
}
size_t BlockManager::GetBlockCount(size_t block_size, double ratio)
{
size_t free{};
size_t total{};
check_cuda_error(cudaMemGetInfo(&free, &total));
return static_cast<size_t>(total * ratio) / block_size;
}
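Editor's note (illustration, not part of this diff): when `block_count` is given as a fraction, `GetBlockCount` converts it into an absolute block count from total device memory. A small worked example with illustrative numbers:

#include <cstddef>
#include <cstdio>

int main()
{
    const size_t total      = 80ULL << 30;   // e.g. 80 GiB of device memory (illustrative)
    const double ratio      = 0.5;           // `block_count` given as a fraction of total memory
    const size_t block_size = 128ULL << 20;  // 128 MiB per block

    const size_t max_block_count = static_cast<size_t>(total * ratio) / block_size;
    std::printf("max_block_count = %zu\n", max_block_count);  // 320
    return 0;
}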
void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
{
std::vector<int> src1(src.size() - delta.size());
std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
src.swap(src1);
std::vector<int> dst1(dst.size() + delta.size());
std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
dst.swap(dst1);
}
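Editor's note (illustration, not part of this diff): `Move` relies on all three id vectors staying sorted, with `delta` a subset of `src` and disjoint from `dst`; `set_difference` removes `delta` from `src` and `set_union` merges it into `dst`, both in linear time. A standalone sketch of the same operation:

#include <algorithm>
#include <cassert>
#include <vector>

// Move the sorted ids in `delta` from sorted `src` to sorted `dst` (mirrors BlockManager::Move).
// Precondition: `delta` is a subset of `src` and disjoint from `dst`.
void move_ids(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
{
    std::vector<int> src1(src.size() - delta.size());
    std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
    src.swap(src1);
    std::vector<int> dst1(dst.size() + delta.size());
    std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
    dst.swap(dst1);
}

int main()
{
    std::vector<int> free_ids{0, 1, 2, 3, 4}, active_ids{5, 7};
    move_ids(free_ids, {1, 3}, active_ids);
    assert((free_ids == std::vector<int>{0, 2, 4}));
    assert((active_ids == std::vector<int>{1, 3, 5, 7}));
    return 0;
}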
std::vector<const Block*> BlockManager::Allocate(int count)
{
while (free_ids_.size() < count) {
if (!Malloc()) {
throw std::runtime_error("out of memory");
}
}
std::vector<const Block*> ret;
std::vector<int> idxs(count);
for (int i = 0; i < count; ++i) {
int idx = free_ids_[i];
idxs[i] = idx;
auto& block = blocks_[idx];
FT_CHECK(is_free(block));
block.ref_count = 1;
block.use_count = 1;
block.unique_id = unique_id_++;
ret.push_back(&block);
}
Move(free_ids_, idxs, active_ids_);
dbg(free_ids_, active_ids_);
return ret;
}
void BlockManager::Evict(int count)
{
std::vector<int> idxs(cached_ids_);
// get first `count` cached ids according to timestamp
std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
return blocks_[i].timestamp < blocks_[j].timestamp;
});
idxs.resize(count);
// sort the retrieved ids
std::sort(idxs.begin(), idxs.end());
// set as free
for (const auto& idx : idxs) {
auto& b = blocks_[idx];
FT_CHECK(is_cached(b));
b.ref_count = 0;
b.unique_id = 0;
b.timestamp = 0;
}
Move(cached_ids_, idxs, free_ids_);
dbg(cached_ids_, free_ids_);
}
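Editor's note (illustration, not part of this diff): `Evict` uses `std::nth_element` as a partial sort — after the call, the first `count` entries of `idxs` are the cached blocks with the smallest timestamps (least recently touched), without fully sorting the rest. A tiny standalone example:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // timestamps[i] = last time block i was touched (illustrative values)
    std::vector<uint64_t> timestamps{40, 10, 30, 20, 50};
    std::vector<int>      idxs{0, 1, 2, 3, 4};
    const int             count = 2;  // evict the two least recently used blocks

    std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
        return timestamps[i] < timestamps[j];
    });
    idxs.resize(count);
    std::sort(idxs.begin(), idxs.end());  // keep the id list sorted, as BlockManager does

    for (int id : idxs)
        std::printf("evict block %d\n", id);  // prints blocks 1 and 3
    return 0;
}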
int BlockManager::Free(const std::vector<const Block*>& bs)
{
std::vector<int> idxs;
for (const auto& p : bs) {
auto& b = blocks_[p->id];
FT_CHECK(is_cached(b));
if (--b.ref_count == 0) {
b.unique_id = 0;
b.timestamp = 0;
idxs.push_back(b.id);
}
}
std::sort(idxs.begin(), idxs.end());
Move(cached_ids_, idxs, free_ids_);
dbg(cached_ids_, free_ids_);
return idxs.size();
}
int BlockManager::Unlock(const std::vector<const Block*>& bs)
{
std::vector<int> idxs;
for (const auto& p : bs) {
auto& block = blocks_[p->id];
FT_CHECK(is_active(block));
if (--block.use_count == 0) {
idxs.push_back(block.id);
}
}
std::sort(idxs.begin(), idxs.end());
Move(active_ids_, idxs, cached_ids_);
dbg(active_ids_, cached_ids_);
return idxs.size();
}
int BlockManager::Lock(const std::vector<const Block*>& bs)
{
std::vector<int> idxs;
for (const auto& p : bs) {
auto& block = blocks_[p->id];
FT_CHECK(is_cached(block));
if (++block.use_count == 1) {
idxs.push_back(p->id);
}
}
std::sort(idxs.begin(), idxs.end());
Move(cached_ids_, idxs, active_ids_);
// dbg(cached_ids_, active_ids_);
return idxs.size();
}
void BlockManager::Touch(const std::vector<const Block*>& bs)
{
std::for_each(bs.crbegin(), bs.crend(), [this](const Block* p) {
FT_CHECK(is_active(*p));
const_cast<Block*>(p)->timestamp = timestamp_++;
});
}
Snapshot BlockManager::TakeSnapshot()
{
std::vector<int> use_count(blocks_.size());
for (const auto& idx : active_ids_) {
use_count[idx] = blocks_[idx].use_count;
}
return {active_count(), cached_count(), free_count(), std::move(use_count)};
}
std::ostream& operator<<(std::ostream& os, const BlockManager& manager)
{
os << "block_size: " << manager.block_size_ << ", ";
os << "max_block_count: " << manager.max_block_count_ << ", ";
os << "chunk_size: " << manager.chunk_size_ << ", ";
os << "chunks: " << manager.chunks_.size() << ", ";
os << "active_ids: " << manager.active_ids_.size() << ", ";
os << "cached_ids: " << manager.cached_ids_.size() << ", ";
os << "free_ids: " << manager.free_ids_.size() << ", ";
os << "blocks: " << manager.blocks_.size() << ", ";
os << "unique_id: " << manager.unique_id_ << ", ";
os << "timestamp: " << manager.timestamp_ << ", ";
os << "allocator: " << manager.allocator_;
return os;
}
std::ostream& operator<<(std::ostream& os, const Block& block)
{
os << "id=" << block.id << ", use_count=" << block.use_count << ", unique_id=" << block.unique_id
<< ", timestamp=" << block.timestamp << ", data=" << block.data;
return os;
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>
#include <iterator>
#include <numeric>
#include <queue>
#include <unordered_map>
#include <vector>
namespace turbomind {
// [L, H, S, D]
// [L, S/x, H, x, D]
struct Block {
int id; // fixed linear id in the pool
int ref_count; // all sequences referencing the block
int use_count; // active sequences using the block
uint64_t unique_id; // unique for every block allocation
uint64_t timestamp;
void* data;
friend std::ostream& operator<<(std::ostream& os, const Block& block);
};
inline bool is_active(const Block& block)
{
return block.ref_count > 0 && block.use_count > 0;
}
inline bool is_cached(const Block& block)
{
return block.ref_count > 0 && block.use_count == 0;
}
inline bool is_free(const Block& block)
{
return block.ref_count == 0 && block.use_count == 0 && block.timestamp == 0;
}
struct Snapshot {
int active;
int cached;
int free;
std::vector<int> use_count;
};
class BlockManager {
public:
explicit BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator);
~BlockManager();
// free -> active (use_count = 1, ref_count = 1)
[[nodiscard]] std::vector<const Block*> Allocate(int count);
// cached -> active (use_count += 1)
[[maybe_unused]] int Lock(const std::vector<const Block*>& bs);
// active -> cached (use_count -= 1)
[[maybe_unused]] int Unlock(const std::vector<const Block*>& bs);
// cached -> free (ref_count = 0)
void Evict(int count);
// cached -> free (ref_count -= 1)
[[maybe_unused]] int Free(const std::vector<const Block*>& bs);
// increase timestamps in reverse order
void Touch(const std::vector<const Block*>& bs);
Snapshot TakeSnapshot();
int max_block_count() const noexcept
{
return max_block_count_;
}
int active_count() const noexcept
{
return active_ids_.size();
}
int cached_count() const noexcept
{
return cached_ids_.size();
}
int free_count() const noexcept
{
return (max_block_count_ - blocks_.size()) + free_ids_.size();
}
friend std::ostream& operator<<(std::ostream& os, const BlockManager&);
private:
static size_t GetBlockCount(size_t block_size, double ratio);
// move indices between sets
static void Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst);
// allocate a chunk of blocks
bool Malloc();
private:
size_t block_size_;
int max_block_count_{};
int chunk_size_{};
IAllocator* allocator_;
std::vector<void*> chunks_;
std::vector<int> active_ids_;
std::vector<int> cached_ids_;
std::vector<int> free_ids_;
std::vector<Block> blocks_; // < 100k
// uint64_t unique_id_{1UL << 63};
uint64_t unique_id_{1};
uint64_t timestamp_{1};
};
} // namespace turbomind
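Editor's note (illustration, not part of this diff): the predicates and transition comments above define a three-state lifecycle per block — free -> active (`Allocate`), active -> cached (`Unlock`), cached -> active (`Lock`), and cached -> free (`Free` or `Evict`). A hedged usage sketch, assuming an `IAllocator*` is available (the concrete CUDA allocator is constructed elsewhere in turbomind and is not shown in this diff):

#include "src/turbomind/models/llama/BlockManager.h"
#include <vector>

using namespace turbomind;

// Illustrative walk-through of the block state machine; `allocator` is an assumed IAllocator*.
void block_lifecycle_demo(IAllocator* allocator)
{
    // 128 MiB blocks, 10% of total device memory, default (sqrt-sized) chunking
    BlockManager manager(128 << 20, 0.1, 0, allocator);

    // free -> active: use_count = 1, ref_count = 1
    std::vector<const Block*> blocks = manager.Allocate(4);
    manager.Touch(blocks);   // stamp them as most recently used

    manager.Unlock(blocks);  // active -> cached (sequence paused, data kept)
    manager.Lock(blocks);    // cached -> active (sequence resumed)
    manager.Unlock(blocks);

    manager.Free(blocks);    // drop the reference: cached -> free
    // Instead of Free, cached blocks can also be reclaimed in LRU order:
    // manager.Evict(4);
}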
......@@ -10,6 +10,8 @@ add_library(Llama STATIC
LlamaV2.cc
LlamaBatch.cc
LlamaCacheManager.cc
BlockManager.cc
SequenceManager.cc
LlamaContextDecoder.cc
LlamaContextAttentionLayer.cc
LlamaDecoderSelfAttentionLayer.cc
......@@ -28,6 +30,7 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
DynamicDecodeLayer
activation_kernels
decoder_masked_multihead_attention
decoder_multihead_attention
bert_preprocess_kernels
decoding_kernels
unfused_attention_kernels
......@@ -48,4 +51,11 @@ endif()
add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)
find_package(Catch2 3 QUIET)
if (Catch2_FOUND)
add_executable(test_cache_manager test_cache_manager.cc)
target_link_libraries(test_cache_manager PRIVATE Llama Catch2::Catch2WithMain)
endif ()
......@@ -6,61 +6,86 @@
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/gemm_test/gemm_func.h"
#include "src/turbomind/utils/logger.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <mutex>
#include <numeric>
#include <sstream>
#include <unordered_map>
namespace turbomind {
void ClearState(BatchState& s)
{
std::fill_n(s.requests.begin(), s.size, nullptr);
std::fill_n(s.sequences.begin(), s.size, nullptr);
s.size = s.active_size = 0;
}
template<typename T>
void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
std::vector<std::shared_ptr<Request>>& infer_reqs)
void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
{
std::unordered_map<uint64_t, int> occurrence;
auto count_occurrence = [&occurrence](const std::vector<std::shared_ptr<Request>>& rs) {
auto count_occurrence = [&occurrence](const Requests& rs) {
for (const auto& r : rs) {
++occurrence[r->id];
}
};
auto invalidate = [](const char* type, std::shared_ptr<Request>& req, int ec) {
TM_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
// We don't need a barrier here because this lambda is called only for
// new requests, which are visible only to the rank-0 thread.
auto reject = [](const char* type, std::shared_ptr<Request>& req, int ec) {
TM_LOG_WARNING(
"[RejectInvalidRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
req->signal.set_value(ec);
req.reset();
};
auto handle_conflict_or_invalid = [this, &occurrence, &invalidate](std::vector<std::shared_ptr<Request>>& rs,
const char* type) {
auto handle_conflict_or_invalid = [this, &occurrence, &reject](Requests& rs, const char* type) {
for (auto& r : rs) {
if (r) {
int ec = 0;
const int input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
const auto get_offset = [&](int token_count) {
return std::max(0, std::min(token_count, r->inputs[rank_].getVal<int>("step", token_count)));
};
if (occurrence[r->id] != 1) {
ec = Request::kConflict;
}
else if (r->start_flag && r->stop_flag) {
ec = Request::kInvalid;
}
else if (!r->start_flag && !llama_->kv_cache_mgr_->contains(r->id)) {
ec = Request::kInvalid;
else if (input_length > session_len_) {
ec = Request::kTooLong;
}
else if (!r->start_flag) {
if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
ec = Request::kInvalid;
}
else if (get_offset(seq->tokens.size()) + input_length > session_len_) {
ec = Request::kTooLong;
}
}
if (ec) {
invalidate(type, r, ec);
reject(type, r, ec);
}
}
}
};
auto drop_invalid = [](std::vector<std::shared_ptr<Request>>& rs) {
auto drop_invalid = [](Requests& rs) {
int count = 0;
for (int i = 0; i < rs.size(); ++i) {
if (rs[i]) {
......@@ -80,14 +105,14 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
for (auto& r : stop_reqs) {
if (r && r->end_flag == false) {
int ec = Request::kInactive;
for (int i = 0; i < batch_size_; ++i) {
if (requests_[i] && requests_[i]->id == r->id) {
for (int i = 0; i < state_->size; ++i) {
if (state_->requests[i] && state_->requests[i]->id == r->id) {
ec = 0;
break;
}
}
if (ec) {
invalidate("stop", r, ec);
reject("stop", r, ec);
}
}
}
......@@ -101,9 +126,9 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
// invalidate requests for busy sequences
for (auto& r : infer_reqs) {
if (r) {
for (int i = 0; i < batch_size_; ++i) {
if (requests_[i] && requests_[i]->id == r->id) {
invalidate("infer", r, Request::kBusy);
for (int i = 0; i < state_->size; ++i) {
if (state_->requests[i] && state_->requests[i]->id == r->id) {
reject("infer", r, Request::kBusy);
break;
}
}
......@@ -115,53 +140,355 @@ void LlamaBatch<T>::verifyRequests(std::vector<std::shared_ptr<Request>>& stop_r
}
template<typename T>
void LlamaBatch<T>::handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests)
auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
{
std::vector<Signal> signals;
for (const auto& r : requests) {
int ec = Request::kFail;
// find matching active sequence
for (int i = 0; i < batch_size_; ++i) {
for (int i = 0; i < state_->size; ++i) {
// stop & optionally erase active sequence
if (requests_[i] && requests_[i]->id == r->id) {
if (state_->requests[i] && state_->requests[i]->id == r->id) {
ec = 0;
finishRequest(i, r->end_flag);
CompleteRequest(i, true, r->end_flag);
state_->requests[i].reset();
break;
}
}
// mismatch, try erase inactive sequence
// mismatch, try erase inactive sequence, in this case there is no active request to finish
if (ec && r->end_flag) {
ec = 0;
llama_->kv_cache_mgr_->erase(r->id);
sequence_manager_->Erase(r->id);
}
// clear output buffers (prevent leaking conversations) if request is successful
if (ec == 0) {
if (rank_ == 0) {
std::unique_lock lock{output_mutex_};
output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
}
auto& output_ids = r->outputs[rank_].at("output_ids");
auto& sequence_length = r->outputs[rank_].at("sequence_length");
check_cuda_error(
cudaMemsetAsync(output_ids.getPtr<int>(), 0, sizeof(int) * output_ids.shape.at(2), stream_));
check_cuda_error(cudaMemsetAsync(sequence_length.getPtr<int>(), 0, sizeof(int), stream_));
Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
Clear(sequence_length.getPtr<int>(), 1);
check_cuda_error(cudaStreamSynchronize(stream_));
}
signals.push_back([=] { r->signal.set_value(ec); });
}
return signals;
}
// When the signal is set threads from LlamaV2::forward can exit
// and free inputs/outputs tensors.
// Therefore we need to make sure that no threads from LlamaV2::internalThreadEntry
// are accessing the tensors.
llama_->shared_state_->barrier->wait();
if (rank_ == 0) {
r->signal.set_value(ec);
template<typename T>
void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
{
auto& state = *incoming_;
FT_CHECK(state.size == 0);
FT_CHECK(state.active_size == 0);
int i = 0;
for (const auto& r : requests) {
// sanity check: incoming requests from the previous iteration should have been moved to `state_`
FT_CHECK(!state.requests[i]);
TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);
state.requests[i] = r;
// get sequence for the request
state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
auto& seq = *state.sequences[i];
if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
/// TODO: revise step setting
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
}
else if (rank_ == 0) {
TM_LOG_WARNING(
"[ProcessInferRequests] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
}
}
const int input_length = r->inputs[rank_].getVal<int>("input_lengths");
const int* input_ids = r->inputs[rank_].getPtr<int>("input_ids");
// `output_ids` contains all token ids of the sequences
const auto output_ids_base = state.output_ids + session_len_ * i;
auto output_ids = output_ids_base;
// copy history tokens
if (!seq.tokens.empty()) {
output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);
}
// copy input tokens
if (input_length) {
output_ids = Copy(input_ids, input_length, output_ids);
}
// total context length (history + input)
state.h_context_length[i] = output_ids - output_ids_base;
state.h_finished[i] = false;
const int request_output_len = state.requests[i]->inputs[rank_].getVal<int>("request_output_len");
state.seq_len_limit[i] = state.h_context_length[i] + request_output_len;
// `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
// the actual sequence length is seq_limit_len + 1, hence seq_limit_len must be truncated to session_len - 1
if (state.seq_len_limit[i] >= session_len_) {
state.seq_len_limit[i] = session_len_ - 1;
if (rank_ == 0) {
const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i];
TM_LOG_WARNING(
"[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
(long)seq.id,
state.h_context_length[i],
request_output_len,
(int)session_len_,
trunc_output_len);
}
}
// compute rope scaling factor
if (r->start_flag) {
seq.rope_theta = model_->attn_params_.rotary_embedding_base;
auto scaling_factor = 1.f;
if (r->inputs[rank_].isExist("rope_scaling_factor")) { // runtime scaling factor
scaling_factor = r->inputs[rank_].getVal<float>("rope_scaling_factor");
}
else if (model_->attn_params_.rope_scaling_factor >= 1.f) { // infer by `seq_len_limit`
scaling_factor = model_->attn_params_.rope_scaling_factor;
auto max_seq_len = state.seq_len_limit[i];
auto max_pos_emb = model_->attn_params_.max_position_embeddings;
if (max_seq_len > max_pos_emb) {
scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
// scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f))
// - 1.f, 1.f);
}
}
if (scaling_factor != 1.f) {
float rope_dim = model_->attn_params_.rotary_embedding_dim;
seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
(long)seq.id,
scaling_factor,
seq.rope_theta);
}
}
state.h_rope_theta[i] = seq.rope_theta;
// recover device states if not a new sequence
if (!r->start_flag) {
Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state);
Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state);
}
// assign priority based on arrival time
r->priority = request_count_++;
// increment pointer
i++;
}
incoming_->size = i;
}
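Editor's note (illustration, not part of this diff): the rope scaling in `ProcessInferRequests` follows the NTK-aware RoPE recipe — instead of interpolating positions, the rotary base is enlarged as theta' = theta * s^(d / (d - 2)), where s is the scaling factor and d the rotary embedding dimension; the per-sequence result travels in `seq.rope_theta` / `rope_theta_` down to the attention kernels. A small numeric sketch with illustrative values:

#include <cmath>
#include <cstdio>

int main()
{
    const float rotary_embedding_base = 10000.f;  // default LLaMA base
    const float rope_dim              = 128.f;    // rotary embedding dimension
    const float scaling_factor        = 2.f;      // e.g. doubling the usable context

    // Same formula as in ProcessInferRequests:
    //   rope_theta *= scaling_factor ^ (rope_dim / (rope_dim - 2))
    const float rope_theta =
        rotary_embedding_base * std::pow(scaling_factor, rope_dim / (rope_dim - 2.f));

    std::printf("rope_theta = %.1f\n", rope_theta);  // ~20221.3 for these values
    return 0;
}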
template<typename T>
void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
bool LlamaBatch<T>::Initialize()
{
NvtxScope scope("initialize");
std::vector<const Sequence*> sequences;
std::vector<Sequence::Status> status;
std::vector<uint64_t> priorities;
std::vector<int> context_lengths;
std::vector<std::pair<BatchState*, int>> coords;
// count the holes introduced by finished requests from the previous iteration or stop requests from
// the current iteration
int holes{};
int active_holes{};
for (int i = 0; i < state_->size; ++i) {
if (!state_->requests[i]) {
++holes;
if (i < state_->active_size) {
++active_holes;
}
}
}
// dbg(holes, active_holes);
auto process = [&](BatchState* state) {
for (int i = 0; i < state->size; ++i) {
if (auto& r = state->requests[i]) {
sequences.push_back(state->sequences[i]);
status.push_back(state->sequences[i]->status);
priorities.push_back(r->priority);
context_lengths.push_back(state->h_context_length[i]);
coords.emplace_back(state, i);
// clear swap-in flags
state->is_swap_in[i] = 0;
}
}
};
process(state_);
process(incoming_);
auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);
if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
dbg(outcome);
}
bool exchange = outcome.swap_in + outcome.swap_out > 0;
std::vector<int> idxs(sequences.size());
std::iota(idxs.begin(), idxs.end(), 0);
if (exchange || holes || incoming_->size) {
// put active ones first
auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
return sequences[idx]->status == Sequence::kActive; // present status
});
// even all available blocks are not enough to hold a single sequence
FT_CHECK_WITH_INFO(active_end != idxs.begin(), "Not enough blocks.");
// move swap-ins to the back
auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) {
return status[idx] == Sequence::kActive; // past status
});
// sort swap-ins according to missing length
if (swapin_beg != active_end) {
std::vector<int> missing_len(sequences.size());
for (int i = 0; i < sequences.size(); ++i) {
missing_len[i] = context_lengths[i] - sequences[i]->cache_len;
}
std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
}
// Copy sequence states to back buffer
FT_CHECK(back_->size == 0 && back_->active_size == 0);
for (const auto& i : idxs) {
auto& s = *sequences[i];
if (exchange) {
const auto& [state, idx] = coords[i];
// backup random states from dynamic decode layers for swap-outs
if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
SaveRandomState(*state, idx);
}
// mark swap-ins
if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
state->is_swap_in[idx] = 1;
}
}
if (s.status == Sequence::kActive) {
++back_->active_size;
}
CopyState(coords[i], {back_, back_->size++});
}
// Swap the buffers
std::swap(state_, back_);
ClearState(*back_);
ClearState(*incoming_);
}
/// Update block ptrs when there were
// 1. swap-in or swap-out
// 2. holes in the active buffer
// 3. new allocations (for existing active sequences)
if (exchange || active_holes || outcome.allocation) {
// Prepare intermediate buffers
h_cu_block_counts_[0] = 0;
auto k_ptrs = h_k_block_ptrs_;
auto v_ptrs = h_v_block_ptrs_;
const int batch_size = state_->active_size;
for (int i = 0; i < batch_size; ++i) {
const auto& seq = *state_->sequences[i];
// cumulative num of blocks
h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();
k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
});
v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
});
}
static_assert(sizeof(uintptr_t) == sizeof(void*));
Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
}
/// Layout of the buffers is changed, generation & sampling need to be re-initialized for correctness when there
/// were
// 1. swap-in or swap-out
// 2. holes in the active buffer
return exchange || active_holes;
}
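Editor's note (illustration, not part of this diff): the block-pointer upload in `Initialize` uses a CSR-like layout — `h_cu_block_counts_` holds cumulative block counts with a leading zero (an exclusive prefix sum), and `h_k_block_ptrs_` / `h_v_block_ptrs_` hold the concatenated block pointers of all active sequences, so sequence `i` owns the pointer range `[cu_block_counts[i], cu_block_counts[i+1])`. A minimal standalone sketch of building that layout with made-up data:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // Pretend block pointers per active sequence (3 sequences with 2, 1 and 3 blocks).
    std::vector<std::vector<uintptr_t>> seq_blocks = {
        {0x1000, 0x2000}, {0x3000}, {0x4000, 0x5000, 0x6000}};

    std::vector<int>       cu_block_counts{0};  // leading zero, like h_cu_block_counts_[0] = 0
    std::vector<uintptr_t> block_ptrs;          // flattened, like h_k_block_ptrs_ / h_v_block_ptrs_

    for (const auto& blocks : seq_blocks) {
        cu_block_counts.push_back(cu_block_counts.back() + (int)blocks.size());
        block_ptrs.insert(block_ptrs.end(), blocks.begin(), blocks.end());
    }

    // Sequence i owns block_ptrs[cu_block_counts[i] .. cu_block_counts[i+1])
    for (size_t i = 0; i < seq_blocks.size(); ++i)
        std::printf("seq %zu: %d blocks starting at offset %d\n",
                    i, cu_block_counts[i + 1] - cu_block_counts[i], cu_block_counts[i]);
    return 0;
}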
template<typename T>
void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst)
{
const auto& [src, i] = _src;
const auto& [dst, j] = _dst;
FT_CHECK((bool)src->requests[i]);
FT_CHECK(!(bool)dst->requests[j]);
dst->h_context_length[j] = src->h_context_length[i];
dst->h_finished[j] = src->h_finished[i];
dst->h_rope_theta[j] = src->h_rope_theta[i];
dst->seq_len_limit[j] = src->seq_len_limit[i];
dst->sequences[j] = src->sequences[i];
dst->is_swap_in[j] = src->is_swap_in[i];
dst->requests[j] = std::move(src->requests[i]);
Copy(src->output_ids + i * session_len_, src->h_context_length[i], dst->output_ids + j * session_len_);
Copy((curandState_t*)src->top_k_curand_state + i, 1, (curandState_t*)dst->top_k_curand_state + j);
Copy((curandState_t*)src->top_p_curand_state + i, 1, (curandState_t*)dst->top_p_curand_state + j);
}
template<typename T>
void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
{
Copy(model_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
Copy(model_->GetTopPState(idx), 1, (curandState_t*)state.top_p_curand_state + idx);
}
template<typename T>
void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
{
dbg(idx);
Copy((curandState_t*)state.top_k_curand_state + idx, 1, model_->GetTopKState(idx));
Copy((curandState_t*)state.top_p_curand_state + idx, 1, model_->GetTopPState(idx));
}
template<typename T>
void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
const size_t batchxbeam = batch_size;
const size_t hidden_units = llama_->hidden_units_;
const size_t vocab_size = llama_->vocab_size_padded_;
const size_t hidden_units = model_->hidden_units_;
const size_t vocab_size = model_->vocab_size_padded_;
const size_t head_dim = model_->size_per_head_;
const size_t local_kv_head_num = model_->local_kv_head_num_;
// +1 padding, BlockIterator does not use predicate
const size_t max_block_count = sequence_manager_->max_block_count() + 1;
context_decoder_input_buf_ =
(T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
......@@ -170,19 +497,26 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
context_decoder_ids_buf_ =
(int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);
tmp_k_cache_buf_ = (T*)allocator_->reMalloc(
tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
tmp_v_cache_buf_ = (T*)allocator_->reMalloc(
tmp_v_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
tmp_k_ptrs_ = (void**)allocator_->reMalloc(tmp_k_ptrs_, sizeof(void*) * batch_size, false);
tmp_v_ptrs_ = (void**)allocator_->reMalloc(tmp_v_ptrs_, sizeof(void*) * batch_size, false);
decoder_input_buf_ = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false);
decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false);
input_ids_buf_ = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
input_length_buf_ = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
history_length_buf_ = (int*)allocator_->reMalloc(history_length_buf_, sizeof(int) * batchxbeam);
context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
total_padding_count_ = (int*)allocator_->reMalloc(total_padding_count_, sizeof(int) * batchxbeam, false);
sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
k_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(k_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam);
v_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(v_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam);
cu_block_counts_ = (int*)allocator_->reMalloc(cu_block_counts_, sizeof(int) * (batch_size + 1));
k_block_ptrs_ = (uintptr_t*)allocator_->reMalloc(k_block_ptrs_, sizeof(uintptr_t) * max_block_count);
v_block_ptrs_ = (uintptr_t*)allocator_->reMalloc(v_block_ptrs_, sizeof(uintptr_t) * max_block_count);
logits_buf_ = (float*)allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
local_logits_buf_ = (float*)allocator_->reMalloc(local_logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
......@@ -193,14 +527,18 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
finished_buf_ = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
request_output_ids_ptrs_ = (int**)allocator_->reMalloc(request_output_ids_ptrs_, sizeof(int*) * batch_size, true);
request_output_ids_lens_ = (int*)allocator_->reMalloc(request_output_ids_lens_, sizeof(int) * batch_size, true);
request_seqlen_ptrs_ = (int**)allocator_->reMalloc(request_seqlen_ptrs_, sizeof(int*) * batch_size, true);
rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false);
is_allocate_buffer_ = true;
}
template<typename T>
void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
{
output_ids_buf_ = (int*)allocator_->reMalloc(output_ids_buf_, sizeof(int) * max_batch_size * session_len_, true);
stop_words_buf_ =
(int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
bad_words_buf_ =
......@@ -221,35 +559,54 @@ void LlamaBatch<T>::allocatePersistantBuffer(size_t max_batch_size)
{"repetition_penalty", h_repetition_penalty_},
{"random_seed", h_random_seed_}};
topk_curandstate_buf_ = allocator_->reMalloc(topk_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true);
topp_curandstate_buf_ = allocator_->reMalloc(topp_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true);
for (auto& s : states_) {
s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
s.top_k_curand_state = allocator_->reMalloc(s.top_k_curand_state, sizeof(curandState_t) * max_batch_size, true);
s.top_p_curand_state = allocator_->reMalloc(s.top_p_curand_state, sizeof(curandState_t) * max_batch_size, true);
}
const size_t max_block_count = sequence_manager_->max_block_count();
{
NcclGuard barrier(llama_->tensor_para_, stream_, true);
NcclGuard barrier(model_->tensor_para_, stream_, true);
h_input_ids_buf_ =
(int*)allocator_->reMalloc(h_input_ids_buf_, sizeof(int) * max_batch_size * session_len_, false, true);
h_input_length_buf_ =
(int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);
h_history_length_buf_ =
(int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * max_batch_size, false, true);
h_context_length_buf_ =
(int*)allocator_->reMalloc(h_context_length_buf_, sizeof(int) * max_batch_size, false, true);
h_sequence_lengths_ =
(int*)allocator_->reMalloc(h_sequence_lengths_, sizeof(int) * max_batch_size, false, true);
h_k_cache_ptr_buf_ =
(uintptr_t*)allocator_->reMalloc(h_k_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true);
h_v_cache_ptr_buf_ =
(uintptr_t*)allocator_->reMalloc(h_v_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true);
h_finished_buf_ = (bool*)allocator_->reMalloc(h_finished_buf_, sizeof(bool) * max_batch_size, false, true);
h_tmp_k_ptrs_ = (void**)allocator_->reMalloc(h_tmp_k_ptrs_, sizeof(void*) * max_batch_size, false, true);
h_tmp_v_ptrs_ = (void**)allocator_->reMalloc(h_tmp_v_ptrs_, sizeof(void*) * max_batch_size, false, true);
h_cu_block_counts_ =
(int*)allocator_->reMalloc(h_cu_block_counts_, sizeof(int) * (max_batch_size + 1), false, true);
h_k_block_ptrs_ =
(uintptr_t*)allocator_->reMalloc(h_k_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
h_v_block_ptrs_ =
(uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
for (auto& s : states_) {
s.h_context_length =
(int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
s.h_rope_theta = (float*)allocator_->reMalloc(s.h_rope_theta, sizeof(float) * max_batch_size, false, true);
}
h_seq_limit_len_ =
(uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);
h_request_output_ids_ptrs_ =
(int**)allocator_->reMalloc(h_request_output_ids_ptrs_, sizeof(int*) * max_batch_size, true, true);
h_request_output_ids_lens_ =
(int*)allocator_->reMalloc(h_request_output_ids_lens_, sizeof(int) * max_batch_size, true, true);
h_request_seqlen_ptrs_ =
(int**)allocator_->reMalloc(h_request_seqlen_ptrs_, sizeof(int*) * max_batch_size, true, true);
}
is_allocate_persistant_buffer_ = true;
}
template<typename T>
void LlamaBatch<T>::freeBuffer()
void LlamaBatch<T>::FreeBuffer()
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (is_allocate_buffer_) {
......@@ -257,19 +614,23 @@ void LlamaBatch<T>::freeBuffer()
allocator_->free((void**)&context_decoder_output_buf_);
allocator_->free((void**)&context_decoder_ids_buf_);
allocator_->free((void**)&tmp_k_cache_buf_);
allocator_->free((void**)&tmp_v_cache_buf_);
allocator_->free((void**)&tmp_k_ptrs_);
allocator_->free((void**)&tmp_v_ptrs_);
allocator_->free((void**)&decoder_input_buf_);
allocator_->free((void**)&decoder_output_buf_);
allocator_->free((void**)&input_ids_buf_);
allocator_->free((void**)&input_length_buf_);
allocator_->free((void**)&history_length_buf_);
allocator_->free((void**)&context_length_buf_);
allocator_->free((void**)&total_padding_count_);
allocator_->free((void**)&sequence_lengths_);
allocator_->free((void**)&k_cache_ptr_buf_);
allocator_->free((void**)&v_cache_ptr_buf_);
allocator_->free((void**)&cu_block_counts_);
allocator_->free((void**)&k_block_ptrs_);
allocator_->free((void**)&v_block_ptrs_);
allocator_->free((void**)&logits_buf_);
allocator_->free((void**)&local_logits_buf_);
......@@ -287,75 +648,101 @@ void LlamaBatch<T>::freeBuffer()
allocator_->free((void**)&finished_buf_);
allocator_->free((void**)&seq_limit_len_);
allocator_->free((void**)&request_output_ids_ptrs_);
allocator_->free((void**)&request_output_ids_lens_);
allocator_->free((void**)&request_seqlen_ptrs_);
allocator_->free((void**)&rope_theta_);
is_allocate_buffer_ = false;
}
if (is_allocate_persistant_buffer_) {
for (auto& s : states_) {
allocator_->free((void**)&s.h_context_length, true);
allocator_->free((void**)&s.h_finished, true);
allocator_->free((void**)&s.h_rope_theta, true);
allocator_->free((void**)&s.output_ids);
}
allocator_->free((void**)&h_tmp_k_ptrs_, true);
allocator_->free((void**)&h_tmp_v_ptrs_, true);
allocator_->free((void**)&h_cu_block_counts_, true);
allocator_->free((void**)&h_k_block_ptrs_, true);
allocator_->free((void**)&h_v_block_ptrs_, true);
allocator_->free((void**)&h_input_ids_buf_, true);
allocator_->free((void**)&h_input_length_buf_, true);
allocator_->free((void**)&h_history_length_buf_, true);
allocator_->free((void**)&h_context_length_buf_, true);
allocator_->free((void**)&h_sequence_lengths_, true);
allocator_->free((void**)&h_k_cache_ptr_buf_, true);
allocator_->free((void**)&h_v_cache_ptr_buf_, true);
allocator_->free((void**)&h_seq_limit_len_, true);
allocator_->free((void**)&h_finished_buf_, true);
allocator_->free((void**)&output_ids_buf_);
allocator_->free((void**)&h_request_output_ids_ptrs_, true);
allocator_->free((void**)&h_request_output_ids_lens_, true);
allocator_->free((void**)&h_request_seqlen_ptrs_, true);
is_allocate_persistant_buffer_ = false;
}
}
template<typename T>
LlamaBatch<T>::LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama):
LlamaBatch<T>::LlamaBatch(int max_batch_size,
int max_context_token_num,
int session_len,
std::unique_ptr<SequenceManager> sequence_manager,
LlamaV2<T>* model):
max_batch_size_(max_batch_size),
max_context_token_num_(max_context_token_num),
session_len_(session_len),
rank_(llama->tensor_para_.rank_),
debug_(llama->debug_),
llama_(llama),
rank_(model->tensor_para_.rank_),
debug_(model->debug_),
step_length_(model->step_length_),
sequence_manager_(std::move(sequence_manager)),
model_(model),
data_type_(getTensorType<T>())
{
stream_ = llama_->stream_;
allocator_ = llama_->allocator_;
cublas_wrapper_ = llama_->cublas_wrapper_;
stream_ = model_->stream_;
allocator_ = model_->allocator_;
cublas_wrapper_ = model_->cublas_wrapper_;
for (auto& s : states_) {
s.requests.resize(max_batch_size);
s.sequences.resize(max_batch_size);
s.seq_len_limit.resize(max_batch_size);
s.is_swap_in.resize(max_batch_size);
}
requests_.resize(max_batch_size);
request_seq_len_limit_.resize(max_batch_size);
cached_seq_.resize(max_batch_size);
state_ = &states_[0];
back_ = &states_[1];
incoming_ = &states_[2];
allocatePersistantBuffer(max_batch_size);
AllocateBuffer(max_batch_size, session_len_);
AllocatePersistantBuffer(max_batch_size);
}
template<typename T>
void LlamaBatch<T>::initializeSampling(int infer_request_count)
void LlamaBatch<T>::InitializeSampling()
{
const int batch_size = state_->active_size;
TensorMap inputs;
for (const auto& param : sampling_params_) {
// find an exemplar that matches the param name
const Tensor* ptr{};
for (int i = 0; i < batch_size_; ++i) {
if (requests_[i]->inputs[rank_].isExist(param.first)) {
ptr = &requests_[i]->inputs[rank_].at(param.first);
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
ptr = &state_->requests[i]->inputs[rank_].at(param.first);
break;
}
}
// fill the batch of the param
if (ptr) {
const auto& ref = *ptr;
auto shape = ref.shape;
FT_CHECK(shape[0] == 1);
shape[0] = batch_size_;
shape[0] = batch_size;
const int size_in_bytes = ref.sizeBytes();
check_cuda_error(cudaMemsetAsync(param.second, 0, size_in_bytes * batch_size_, stream_));
for (int i = 0; i < batch_size_; ++i) {
if (requests_[i]->inputs[rank_].isExist(param.first)) {
auto& src = requests_[i]->inputs[rank_].at(param.first);
Clear((std::byte*)param.second, size_in_bytes * batch_size);
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
auto& src = state_->requests[i]->inputs[rank_].at(param.first);
FT_CHECK(ref.shape == src.shape);
check_cuda_error(cudaMemcpyAsync((uint8_t*)param.second + size_in_bytes * i,
src.getPtr<void>(),
size_in_bytes,
cudaMemcpyDefault,
stream_));
Copy(src.getPtr<std::byte>(), size_in_bytes, (std::byte*)param.second + size_in_bytes * i);
}
}
inputs.insert({param.first, {ref.where, ref.type, shape, param.second}});
......@@ -367,35 +754,27 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
inputs_ = std::move(inputs);
llama_->dynamic_decode_layer_->setup(batch_size_, 1, &inputs_);
for (int i = 0; i < batch_size_; ++i) {
// recover random states if not a new request or new request w/o "random_seed"
if (i < batch_size_ - infer_request_count || !requests_[i]->inputs[rank_].isExist("random_seed")) {
check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topk_curandstate_buf() + i,
(curandState_t*)topk_curandstate_buf_ + i,
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topp_curandstate_buf() + i,
(curandState_t*)topp_curandstate_buf_ + i,
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
// recover random states if not a new request
for (int i = 0; i < batch_size; ++i) {
if (!state_->requests[i]->start_flag && state_->is_swap_in[i]) {
LoadRandomState(*state_, i);
}
}
handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
handleOptArg(&inputs_, "end_id", end_ids_buf_, model_->end_id_, batch_size);
cudaStreamSynchronize(0);
}
template<typename T>
void LlamaBatch<T>::initializeGeneration()
auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
{
max_context_len_ = *std::max_element(h_context_length_buf_, h_context_length_buf_ + batch_size_);
const int batch_size = state_->active_size;
const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
check_cuda_error(cudaMemsetAsync(token_ids_buf_, 0, sizeof(int) * batch_size_ * session_len_ * 2, stream_));
invokeTransposeAxis01(token_ids_buf_, output_ids_buf_, batch_size_, session_len_, 1, stream_);
Clear(token_ids_buf_, batch_size * session_len_);
invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
sync_check_cuda_error();
// token_ids_buf_[s, b]
......@@ -404,125 +783,134 @@ void LlamaBatch<T>::initializeGeneration()
// ABCDEFGHi -> ABCDEFGHi i
// ABCDEFGh ABCDEFGh h
// ABCd ABCd d
for (int i = 0; i < batch_size_; ++i) {
for (int i = 0; i < batch_size; ++i) {
auto token_ids = token_ids_buf_ + i;
auto p_src = h_context_length_buf_[i] - 1;
auto p_dst = max_context_len_ - 1;
auto p_src = state_->h_context_length[i] - 1;
auto p_dst = max_context_len - 1;
if (p_src != p_dst) { // dst and src of `cudaMemcpyAsync` must not overlap
check_cuda_error(cudaMemcpyAsync(token_ids + p_dst * batch_size_,
token_ids + p_src * batch_size_,
sizeof(int),
cudaMemcpyDefault,
stream_));
Copy(token_ids + p_src * batch_size, 1, token_ids + p_dst * batch_size);
}
}
check_cuda_error(cudaMemcpyAsync(
context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(cudaMemcpyAsync(
k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(cudaMemcpyAsync(
v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(
cudaMemcpyAsync(sequence_lengths_, context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
Copy(state_->h_context_length, batch_size, context_length_buf_); // also referenced in `SetOutputTensors`
Copy(context_length_buf_, batch_size, sequence_lengths_);
// `sequence_lengths_` will be increased by dynamic decode
// note that "sequence length" has different semantics in the decoder and in the output
// - in decoder it means length of sequence that has kv cache already computed
// - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
invokePlusScalar(sequence_lengths_, -1, batch_size_, stream_);
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
sync_check_cuda_error();
// total_padding_count_
// decoding starts at max_context_len
check_cuda_error(cudaMemsetAsync(total_padding_count_, 0, sizeof(int) * batch_size_, stream_));
invokeUpdatePaddingCount(total_padding_count_, //
context_length_buf_,
max_context_len_,
batch_size_,
1,
stream_);
sync_check_cuda_error();
// used for dispatching split-k decoding kernels
const int sum_seq_len =
std::accumulate(state_->h_context_length, state_->h_context_length + batch_size, -batch_size);
const int max_seq_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size) - 1;
// seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted
// for
for (int i = 0; i < batch_size; ++i) {
h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
if (max_context_len >= h_seq_limit_len_[i]) { // mask finished sequences
state_->h_finished[i] = true;
}
}
Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
Copy(state_->h_finished, batch_size, finished_buf_);
// seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for
for (int i = 0; i < batch_size_; ++i) {
h_seq_limit_len_[i] = request_seq_len_limit_[i] + (max_context_len_ - h_context_length_buf_[i]);
// mask finished sequences
h_finished_buf_[i] = max_context_len_ >= h_seq_limit_len_[i];
for (int i = 0; i < batch_size; ++i) {
Tensor& output_ids = state_->requests[i]->outputs[rank_].at("output_ids");
int* req_output_ids_ptr = output_ids.getPtr<int>();
int* req_seqlen_ptr = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
h_request_output_ids_ptrs_[i] = req_output_ids_ptr;
h_request_output_ids_lens_[i] = output_ids.shape.at(2);
h_request_seqlen_ptrs_[i] = req_seqlen_ptr;
FT_CHECK(h_request_output_ids_ptrs_[i]);
FT_CHECK(h_request_output_ids_lens_[i]);
FT_CHECK(h_request_seqlen_ptrs_[i]);
}
check_cuda_error(
cudaMemcpyAsync(seq_limit_len_, h_seq_limit_len_, sizeof(uint32_t) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(
cudaMemcpyAsync(finished_buf_, h_finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_));
Copy(h_request_output_ids_ptrs_, batch_size, request_output_ids_ptrs_);
Copy(h_request_output_ids_lens_, batch_size, request_output_ids_lens_);
Copy(h_request_seqlen_ptrs_, batch_size, request_seqlen_ptrs_);
Copy(state_->h_rope_theta, batch_size, rope_theta_);
// ! range of step_ [1, 2 * session_len]
// consider a sequence with context_len == session_len and another sequence with context_len == 1 and
// request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
step_ = max_context_len_;
const int start_step = max_context_len;
if (rank_ == 0) {
TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size_);
TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len_);
TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len);
TM_LOG_INFO("[initGen] slot sequence_id context_len seq_limit_len finished");
for (int i = 0; i < batch_size_; ++i) {
for (int i = 0; i < batch_size; ++i) {
TM_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d",
i,
(long)cached_seq_[i].id,
h_context_length_buf_[i],
(long)state_->sequences[i]->id,
state_->h_context_length[i],
(int)h_seq_limit_len_[i],
(int)h_finished_buf_[i]);
(int)state_->h_finished[i]);
}
}
// for (int i = 0; i < batch_size; ++i) {
// gSequenceIds(i) = state_->requests[i]->id;
// }
return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
}
template<typename T>
bool LlamaBatch<T>::generate()
bool LlamaBatch<T>::Generate(GenerationState& g)
{
NvtxScope scope("Generate");
const int batch_size = state_->active_size;
constexpr int kLogInterval = 10;
if (rank_ == 0 && (step_ - 1) % kLogInterval == 0) {
TM_LOG_INFO("------------------------- step = %d -------------------------", step_ - 1);
if (rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
}
const bool is_first_step = step_ == max_context_len_;
const bool is_first_step = (g.step == g.max_init_ctx_len);
std::vector<int> prev;
if (debug_ && rank_ == 0 && is_first_step) {
prev.resize(batch_size_);
cudaMemcpyAsync(prev.data(),
token_ids_buf_ + (step_ - 1) * batch_size_,
sizeof(int) * batch_size_,
cudaMemcpyDefault,
stream_);
prev.resize(batch_size);
Copy(token_ids_buf_ + (g.step - 1) * batch_size, batch_size, prev.data());
}
// embeddingLookup(step_ - 1);
llama_->embeddingLookup(decoder_input_buf_, //
model_->embeddingLookup(decoder_input_buf_, //
token_ids_buf_,
batch_size_,
step_ - 1);
batch_size,
g.step - 1);
llama_->decoderForward(decoder_output_buf_,
k_cache_ptr_buf_,
v_cache_ptr_buf_,
model_->decoderForward(decoder_output_buf_,
k_block_ptrs_,
v_block_ptrs_,
decoder_input_buf_,
sequence_lengths_,
total_padding_count_,
finished_buf_,
step_,
cu_block_counts_,
rope_theta_,
g.step,
0,
session_len_,
batch_size_);
g.sum_seq_len,
g.max_seq_len,
batch_size);
llama_->postDecodeEmbedding(logits_buf_, //
model_->postDecodeEmbedding(logits_buf_, //
local_logits_buf_,
decoder_output_buf_,
batch_size_);
batch_size);
// stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
// not supported yet.
bool should_stop{};
llama_->dynamicDecode(token_ids_buf_,
model_->dynamicDecode(token_ids_buf_,
finished_buf_,
sequence_lengths_,
&should_stop,
......@@ -532,17 +920,16 @@ bool LlamaBatch<T>::generate()
seq_limit_len_,
context_length_buf_,
end_ids_buf_,
step_,
g.step,
0,
max_context_len_,
g.max_init_ctx_len,
session_len_ * 2,
batch_size_);
batch_size);
if (debug_ && rank_ == 0) {
std::vector<int> curr(batch_size_);
std::vector<int> curr(batch_size);
cudaMemcpyAsync(
curr.data(), token_ids_buf_ + step_ * batch_size_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_);
Copy(token_ids_buf_ + g.step * batch_size, batch_size, curr.data());
cudaStreamSynchronize(stream_);
if (is_first_step) {
......@@ -550,342 +937,189 @@ bool LlamaBatch<T>::generate()
for (int k = 0; k < prev.size(); ++k) {
sprev << std::setw(6) << prev[k];
}
TM_LOG_INFO("[ lookup ] step = %d, [%s]", step_ - 1, sprev.str().c_str());
TM_LOG_INFO("[ lookup ] step = %d, [%s]", g.step - 1, sprev.str().c_str());
}
std::stringstream scurr;
for (int k = 0; k < curr.size(); ++k) {
scurr << std::setw(6) << curr[k];
}
TM_LOG_INFO("[generate] step = %d, [%s]", step_ - 1, scurr.str().c_str());
TM_LOG_INFO("[generate] step = %d, [%s]", g.step - 1, scurr.str().c_str());
}
////////////////////////////////////////////////
/// ! increase the step counter
++step_;
/// ! increase the counters
g.step += 1;
g.max_seq_len += 1;
g.sum_seq_len += batch_size;
return !should_stop;
}
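// --- Illustrative sketch (not part of this change set): `token_ids_buf_` is laid out
// step-major as [S, B], so the tokens emitted at one step for all active slots are
// contiguous. Assuming that layout, a host-side analogue of the per-step `Copy` calls
// above would look like:
#include <cstring>
#include <vector>

static std::vector<int> TokensAtStep(const int* token_ids, int step, int batch_size)
{
    std::vector<int> out(batch_size);
    // row `step` begins at token_ids + step * batch_size
    std::memcpy(out.data(), token_ids + step * batch_size, sizeof(int) * batch_size);
    return out;
}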
template<typename T>
void LlamaBatch<T>::initialize(const std::vector<std::shared_ptr<Request>>& infer_requests)
void LlamaBatch<T>::ContextDecode()
{
FT_CHECK(batch_size_ + infer_requests.size() <= max_batch_size_);
const int infer_request_count = infer_requests.size();
allocateBuffer(batch_size_ + infer_request_count, session_len_);
// handle infer requests
std::vector<int> tmp_input_length(infer_request_count);
std::vector<CachedSeq> tmp_cached_seq;
tmp_cached_seq.reserve(infer_request_count);
int tmp_max_input_length = 0;
for (int i = 0; i < infer_request_count; ++i) {
auto& r = *infer_requests[i];
LlamaCacheManager::Sequence seq{};
if (r.start_flag) {
seq = llama_->kv_cache_mgr_->create(r.id, stream_);
}
else {
seq = llama_->kv_cache_mgr_->fetch(r.id, stream_);
}
const auto batch_size = state_->active_size;
const int step = r.inputs[rank_].getVal<int>("step", -1);
if (step >= 0) {
if (step <= seq.token_ids.size()) {
seq.token_ids.resize(step);
seq.cache_len = std::min(seq.cache_len, (size_t)step);
}
else if (rank_ == 0) {
TM_LOG_WARNING("[initialize] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
int base = -1;
for (int i = 0; i < batch_size; ++i) {
if (state_->is_swap_in[i]) {
const auto& seq = *state_->sequences[i];
dbg(std::tuple(i, state_->h_context_length[i], seq.cache_len));
if (const int missing = state_->h_context_length[i] - seq.cache_len; missing > 1) {
base = base < 0 ? i : base;
dbg(seq.tokens, seq.cache_len);
Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
// decrement the input/context length by 1 to skip the last input token (it will be processed by the decoder later)
h_input_length_buf_[i] = missing - 1;
}
}
// input length with missing cache accounted for
int actual_input_len = r.inputs[rank_].getVal<int>("input_lengths") + (seq.token_ids.size() - seq.cache_len);
// insert `start_id` for empty sequences
if (seq.token_ids.empty() && actual_input_len == 0) {
seq.token_ids.push_back(llama_->start_id_);
seq.cache_len = 0;
actual_input_len = seq.token_ids.size() - seq.cache_len;
}
tmp_input_length[i] = actual_input_len;
tmp_max_input_length = std::max((int)tmp_max_input_length, actual_input_len);
tmp_cached_seq.push_back(std::move(seq));
}
FT_CHECK(tmp_max_input_length > 0);
const int max_input_length = tmp_max_input_length;
// arrange requests in ascending order w.r.t. actual input lengths, so that requests needing context decoding
// are grouped together
{
std::vector<int> idxs(tmp_input_length.size());
std::iota(idxs.begin(), idxs.end(), 0);
std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return tmp_input_length[i] < tmp_input_length[j]; });
for (int i = 0; i < idxs.size(); ++i) {
requests_[batch_size_ + i] = infer_requests[idxs[i]];
cached_seq_[batch_size_ + i] = tmp_cached_seq[idxs[i]];
}
if (base < 0) {
// TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
return;
}
const int count = batch_size_ + infer_requests.size();
std::vector<int> tmp_input_len(count);
const int context_decode_count = batch_size - base;
for (int i = batch_size_; i < count; ++i) {
const auto& seq = cached_seq_[i];
h_input_length_buf_[i] = requests_[i]->inputs[rank_].getVal<int>("input_lengths");
tmp_input_len[i] = h_input_length_buf_[i];
// prepare output ids
// <--------> max_context_len
// aaaAAAA
// bbbbBBBBBB
// ccCCC
auto output_ids_ptr = output_ids_buf_ + i * session_len_;
// clear the persistent buffer to prevent leaking previous conversation
check_cuda_error(cudaMemsetAsync(output_ids_ptr, 0, sizeof(int) * session_len_, stream_));
if (!seq.token_ids.empty()) {
check_cuda_error(cudaMemcpyAsync(output_ids_ptr, //
seq.token_ids.data(),
sizeof(int) * seq.token_ids.size(),
cudaMemcpyDefault,
stream_));
output_ids_ptr += seq.token_ids.size();
}
Copy(state_->h_context_length, batch_size, context_length_buf_);
Copy(state_->h_rope_theta, batch_size, rope_theta_);
Copy(h_input_length_buf_, batch_size, input_length_buf_);
if (h_input_length_buf_[i]) {
auto input_ids_ptr = requests_[i]->inputs[rank_].getPtr<int>("input_ids");
check_cuda_error(cudaMemcpyAsync(output_ids_ptr, //
input_ids_ptr,
sizeof(int) * h_input_length_buf_[i],
cudaMemcpyDefault,
stream_));
}
check_cuda_error(cudaStreamSynchronize(stream_));
const auto tick = std::chrono::high_resolution_clock::now();
if (!requests_[i]->start_flag && !seq.random_state_.empty()) {
check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + i,
seq.random_state_.data(),
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + i,
seq.random_state_.data() + sizeof(curandState_t),
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
}
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
}
for (int i = batch_size_; i < count; ++i) {
const auto& seq = cached_seq_[i];
const int missed = (int)seq.token_ids.size() - seq.cache_len;
auto input_ids_buf = input_ids_buf_ + i * session_len_;
FT_CHECK(missed >= 0);
if (missed > 0) {
check_cuda_error(cudaMemcpyAsync(input_ids_buf, //
seq.token_ids.data() + seq.cache_len,
sizeof(int) * missed,
cudaMemcpyDefault,
stream_));
input_ids_buf += missed;
// decrement the input/context length by 1 to skip the last input token (it will be processed by the decoder later)
invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
// find sub-batch offsets
std::vector<int> offsets{base};
std::vector<int> max_context_cnts;
int accum_size = 0;
int accum_input_count = 0;
int max_context_count = 0;
for (int i = base; i < batch_size; ++i) {
int size = accum_size + 1;
int input_count = accum_input_count + h_input_length_buf_[i];
int context_count = std::max(max_context_count, state_->h_context_length[i] - 1);
// `cu_seqlens` is available for q, so the inputs need no padding
// the kernels expect a uniform k/v cache length -> `max_context_count * size <= max_context_token_num_`
if (input_count <= max_context_token_num_ && context_count * size <= max_context_token_num_) {
accum_size = size;
accum_input_count = input_count;
max_context_count = context_count;
}
auto& input_ids = requests_[i]->inputs[rank_].at("input_ids");
check_cuda_error(cudaMemcpyAsync(input_ids_buf, //
input_ids.getPtr<int>(),
sizeof(int) * h_input_length_buf_[i],
cudaMemcpyDefault,
stream_));
h_input_length_buf_[i] += missed;
h_history_length_buf_[i] = seq.cache_len;
h_context_length_buf_[i] = h_input_length_buf_[i] + h_history_length_buf_[i];
const int request_output_len = requests_[i]->inputs[rank_].getVal<int>("request_output_len");
request_seq_len_limit_[i] = h_context_length_buf_[i] + request_output_len;
// `length_criterion` sets the finish flag when step >= seq_limit_len; however, when step == seq_limit_len
// the actual sequence length is seq_limit_len + 1, hence seq_limit_len must be truncated to session_len - 1
if (request_seq_len_limit_[i] >= session_len_) {
request_seq_len_limit_[i] = session_len_ - 1;
if (rank_ == 0) {
const int trunc_output_len = request_seq_len_limit_[i] - h_context_length_buf_[i];
TM_LOG_WARNING(
"[initialize] [%ld] total sequence length (%d + %d) exceeds session_len (%d), request_output_len is truncated to %d",
(long)seq.id,
h_context_length_buf_[i],
request_output_len,
(int)session_len_,
trunc_output_len);
}
else {
offsets.push_back(i);
max_context_cnts.push_back(max_context_count);
accum_size = 1;
accum_input_count = h_input_length_buf_[i];
max_context_count = state_->h_context_length[i] - 1;
}
h_k_cache_ptr_buf_[i] = (uint64_t)seq.k_cache;
h_v_cache_ptr_buf_[i] = (uint64_t)seq.v_cache;
}
const int max_context_len = *std::max_element(h_context_length_buf_ + batch_size_, h_context_length_buf_ + count);
batch_size_ = count;
max_context_len_ = max_context_len;
step_ = max_context_len;
check_cuda_error(
cudaMemcpyAsync(input_length_buf_, h_input_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(cudaMemcpyAsync(
history_length_buf_, h_history_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(cudaMemcpyAsync(
context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(cudaMemcpyAsync(
k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(cudaMemcpyAsync(
v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_));
if (llama_->tensor_para_.rank_ == 0) {
TM_LOG_INFO("[init] infer_request_count = %d", (int)infer_request_count);
TM_LOG_INFO("[init] batch_size = %d", (int)batch_size_);
TM_LOG_INFO("[init] session_len = %d", (int)session_len_);
TM_LOG_INFO("[init] max_input_length = %d", (int)max_input_length);
TM_LOG_INFO("[init] max_context_len = %d", (int)max_context_len);
TM_LOG_INFO(
"[init] slot sequence_id history_len input_len context_len tmp_input_len token_ids.size cache_len");
for (int i = batch_size_ - infer_request_count; i < batch_size_; ++i) {
TM_LOG_INFO("[init] %4d %11ld %11d %9d %11d %13d %14d %9d",
i,
(int)cached_seq_[i].id,
h_history_length_buf_[i],
h_input_length_buf_[i],
h_context_length_buf_[i],
tmp_input_len[i],
(int)cached_seq_[i].token_ids.size(),
(int)cached_seq_[i].cache_len);
offsets.push_back(batch_size);
max_context_cnts.push_back(max_context_count);
dbg(offsets, max_context_cnts);
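// Each sub-batch [offsets[k], offsets[k+1]) therefore satisfies both limits above: its
// packed input tokens fit in `max_context_token_num_`, and `max_context_cnts[k]` (the
// largest context length in the sub-batch, minus the skipped last token) times the
// sub-batch size also fits, so it serves as the uniform k/v length of the temporary
// linear cache below.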
// context decode on sub-batches
for (int k = 0; k < offsets.size() - 1; ++k) {
int first = offsets[k];
int last = offsets[k + 1];
int sub_batch_size = last - first;
T* k_ptr = tmp_k_cache_buf_;
T* v_ptr = tmp_v_cache_buf_;
std::vector<int> decode_indices{};
std::vector<int> decode_lengths{};
int max_input_len{};
auto input_ids = context_decoder_ids_buf_;
TM_LOG_INFO("first = %d, last = %d", first, last);
for (int i = first; i < last; ++i) {
TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
dbg(i, h_input_length_buf_[i]);
h_tmp_k_ptrs_[i] = k_ptr;
h_tmp_v_ptrs_[i] = v_ptr;
k_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
v_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
decode_indices.push_back(i);
decode_lengths.push_back(h_input_length_buf_[i]);
max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
}
}
}
int token_count = input_ids - context_decoder_ids_buf_;
dbg(token_count, max_input_len, max_context_cnts[k]);
template<typename T>
void LlamaBatch<T>::contextDecode()
{
int base = -1;
for (int i = 0; i < batch_size_; ++i) {
if (h_input_length_buf_[i] > 1) {
base = i;
break;
}
}
if (base >= 0) {
check_cuda_error(cudaStreamSynchronize(stream_));
const auto tick = std::chrono::high_resolution_clock::now();
Copy(h_tmp_k_ptrs_ + first, sub_batch_size, tmp_k_ptrs_ + first);
Copy(h_tmp_v_ptrs_ + first, sub_batch_size, tmp_v_ptrs_ + first);
const int context_decode_count = batch_size_ - base;
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
}
invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_);
invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
auto get_input_len = [this](int index) { return h_input_length_buf_[index] - 1; };
auto get_context_len = [this](int index) { return h_context_length_buf_[index] - 1; };
std::vector<int> decode_indices{base};
std::vector<int> decode_lengths{get_input_len(base)};
auto token_num = get_input_len(base);
auto max_input_len = get_input_len(base);
auto max_context_len = get_context_len(base);
auto offset = base;
for (int i = offset + 1; i <= batch_size_; ++i) {
if (i == batch_size_ || token_num + h_context_length_buf_[i] > max_context_token_num_) {
const int context_decode_batch_size = i - offset;
if (rank_ == 0) {
TM_LOG_INFO(
"[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
base,
context_decode_batch_size,
token_num,
max_input_len,
max_context_len);
}
// construct context_decoder_ids w/o padding
// aaaa____
// bb______ -> aaaabbcccccccc
// cccccccc
auto context_decoder_ids = context_decoder_ids_buf_;
for (int j = offset; j < i; ++j) {
check_cuda_error(cudaMemcpyAsync(context_decoder_ids,
input_ids_buf_ + j * session_len_,
sizeof(int) * get_input_len(j),
cudaMemcpyDefault,
stream_));
context_decoder_ids += get_input_len(j);
}
llama_->contextDecode(nullptr,
k_cache_ptr_buf_ + offset,
v_cache_ptr_buf_ + offset,
context_decoder_input_buf_,
context_decoder_output_buf_,
context_decoder_ids_buf_,
input_length_buf_ + offset,
history_length_buf_ + offset,
context_length_buf_ + offset,
token_num,
max_input_len,
max_context_len,
session_len_,
context_decode_batch_size);
// compute logits of inputs if requested
outputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
if (i < batch_size_) {
// initialize next sub-batch
token_num = get_input_len(i);
max_input_len = get_input_len(i);
max_context_len = get_context_len(i);
offset = i;
decode_indices = {i};
decode_lengths = {get_input_len(i)};
}
}
else {
// add to current sub-batch
token_num += get_input_len(i);
max_input_len = std::max(max_input_len, get_input_len(i));
max_context_len = std::max(max_context_len, get_context_len(i));
decode_indices.push_back(i);
decode_lengths.push_back(get_input_len(i));
}
TM_LOG_INFO(
"[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
base,
sub_batch_size,
token_count,
max_input_len,
max_context_cnts[k]);
}
invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_);
dbg(first, last);
dbg(k_block_ptrs_, v_block_ptrs_);
for (int i = offset; i < batch_size_; ++i) {
h_input_length_buf_[i] = 0;
if (1) {
std::vector<int> input_len(sub_batch_size);
std::vector<int> context_len(sub_batch_size);
Copy(input_length_buf_ + first, sub_batch_size, input_len.data());
Copy(context_length_buf_ + first, sub_batch_size, context_len.data());
cudaStreamSynchronize(stream_);
dbg(input_len, context_len);
}
check_cuda_error(cudaStreamSynchronize(stream_));
const auto tock = std::chrono::high_resolution_clock::now();
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
model_->contextDecode(nullptr,
k_block_ptrs_,
v_block_ptrs_,
tmp_k_ptrs_ + first,
tmp_v_ptrs_ + first,
context_decoder_input_buf_,
context_decoder_output_buf_,
context_decoder_ids_buf_,
input_length_buf_ + first,
context_length_buf_ + first,
cu_block_counts_ + first,
rope_theta_ + first,
token_count,
max_input_len,
max_context_cnts[k],
max_context_cnts[k],
sub_batch_size);
// compute logits of inputs if requested
OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
}
invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
std::fill(h_input_length_buf_ + base, h_input_length_buf_ + batch_size, 0);
// `SequenceManager` needs real-time value of cache length
for (int i = base; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i] - 1; // -1 since we skip last token
}
}
else if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
check_cuda_error(cudaStreamSynchronize(stream_));
const auto tock = std::chrono::high_resolution_clock::now();
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
}
}
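// --- Illustrative sketch (assumption, not part of this change set): how variable-length
// inputs are packed without padding when filling `context_decoder_ids_buf_` above, i.e.
//   aaaa____
//   bb______  ->  aaaabbcccccccc  with cu_seqlens = {0, 4, 6, 14}
//   cccccccc
#include <vector>

struct PackedInputs {
    std::vector<int> ids;         // concatenated token ids, no padding
    std::vector<int> cu_seqlens;  // prefix sums marking per-sequence boundaries
};

static PackedInputs PackWithoutPadding(const std::vector<std::vector<int>>& inputs)
{
    PackedInputs p;
    p.cu_seqlens.push_back(0);
    for (const auto& seq : inputs) {
        p.ids.insert(p.ids.end(), seq.begin(), seq.end());
        p.cu_seqlens.push_back(static_cast<int>(p.ids.size()));
    }
    return p;
}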
template<typename T>
void LlamaBatch<T>::outputContextLogits(T* context_decoder_output,
void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output,
const std::vector<int>& indices,
const std::vector<int>& lengths)
{
......@@ -894,7 +1128,7 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
{
bool is_return_logits = false;
for (int k = 0; k < indices.size(); ++k) {
auto& request = requests_[indices[k]];
auto& request = state_->requests[indices[k]];
output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
num_token += lengths[k];
if (output_logits.back()) {
......@@ -907,230 +1141,348 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
}
if (context_logits_buf_ == nullptr) {
NcclGuard guard(llama_->tensor_para_, stream_, true);
NcclGuard guard(model_->tensor_para_, stream_, true);
context_logits_buf_ =
(float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
const auto tp = llama_->tensor_para_.world_size_;
(float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * max_context_token_num_);
const auto tp = model_->tensor_para_.world_size_;
if (tp > 1) {
FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
FT_CHECK(model_->vocab_size_padded_ % tp == 0);
const auto local_vocab_size = model_->vocab_size_padded_ / tp;
local_context_logits_buf_ =
(float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
}
}
llama_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
model_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
auto logits = context_logits_buf_;
for (int k = 0; k < indices.size(); ++k) {
if (output_logits[k]) {
check_cuda_error(cudaMemcpyAsync(output_logits[k],
logits,
sizeof(float) * llama_->vocab_size_ * lengths[k],
cudaMemcpyDefault,
stream_));
Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]);
}
logits += llama_->vocab_size_padded_ * lengths[k];
logits += model_->vocab_size_padded_ * lengths[k];
}
}
template<typename T>
void LlamaBatch<T>::finish()
auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
{
// secure info needed by `synchronize()`
check_cuda_error(
cudaMemcpyAsync(h_finished_buf_, finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_));
check_cuda_error(
cudaMemcpyAsync(h_sequence_lengths_, sequence_lengths_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_));
NvtxScope scope("Finish");
const int batch_size = state_->active_size;
setOutputTensors(step_);
// save the info needed by `Initialize()`
Copy(finished_buf_, batch_size, state_->h_finished);
check_cuda_error(cudaStreamSynchronize(stream_));
// invariant: context_length = sequence_length + 1
invokePlusScalar(sequence_lengths_, 1, batch_size, stream_);
Copy(sequence_lengths_, batch_size, state_->h_context_length);
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
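// `sequence_lengths_` is temporarily bumped by 1 on the device so the host copy directly
// receives context lengths (the +1 invariant above), then restored in place.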
if (rank_ == 0 && llama_->ffi_lock_) {
llama_->ffi_lock_(1);
}
for (int i = 0; i < batch_size_; ++i) {
FT_CHECK(requests_[i] != nullptr);
if (requests_[i]->stream_cb && rank_ == 0) {
requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
if constexpr (0) {
std::unique_lock<std::mutex> lock;
if (rank_ == 0) {
NvtxScope _("acquire_outputs");
// wait for previous output operations
lock = std::unique_lock{output_mutex_};
output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
}
SetOutputTensors(g);
check_cuda_error(cudaStreamSynchronize(stream_));
if (rank_ == 0) {
NvtxScope _("signal_output_thread");
// enqueue new output requests
for (int i = 0; i < batch_size; ++i) {
FT_CHECK(state_->requests[i] != nullptr);
if (state_->requests[i]->stream_cb) {
output_reqs_.push_back(state_->requests[i]);
}
}
lock.unlock();
// notify output thread when we do have stream cbs to call
if (!output_reqs_.empty()) {
output_cv_.notify_one();
}
}
}
if (rank_ == 0 && llama_->ffi_lock_) {
llama_->ffi_lock_(0);
else {
SetOutputTensors(g);
check_cuda_error(cudaStreamSynchronize(stream_));
{
NvtxScope _("output_cb");
if (rank_ == 0 && model_->ffi_lock_) {
model_->ffi_lock_(1);
}
for (int i = 0; i < batch_size; ++i) {
FT_CHECK(state_->requests[i] != nullptr);
if (state_->requests[i]->stream_cb && rank_ == 0) {
state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
}
}
if (rank_ == 0 && model_->ffi_lock_) {
model_->ffi_lock_(0);
}
}
}
if (debug_ && rank_ == 0) {
std::stringstream ss;
for (int i = 0; i < batch_size_; ++i) {
ss << (i ? ", " : "") << "(" << h_sequence_lengths_[i] << "," << h_finished_buf_[i] << ")";
for (int i = 0; i < batch_size; ++i) {
ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
}
TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
}
for (int i = 0; i < batch_size_; ++i) {
if (h_finished_buf_[i]) {
finishRequest(i, false);
++finished_count_;
// `SequenceManager` needs real-time value of cache length
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i];
}
}
}
template<typename T>
void LlamaBatch<T>::synchronize()
{
// compact
int idx = 0;
for (int i = 0; i < batch_size_; ++i) {
if (requests_[i]) {
h_input_length_buf_[idx] = 0;
h_history_length_buf_[idx] = 0;
h_context_length_buf_[idx] = h_sequence_lengths_[i] + 1;
h_sequence_lengths_[idx] = h_context_length_buf_[idx];
check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + idx,
llama_->dynamic_decode_layer_->topk_curandstate_buf() + i,
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + idx,
llama_->dynamic_decode_layer_->topp_curandstate_buf() + i,
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
if (i != idx) {
h_finished_buf_[idx] = h_finished_buf_[i];
request_seq_len_limit_[idx] = request_seq_len_limit_[i];
h_k_cache_ptr_buf_[idx] = h_k_cache_ptr_buf_[i];
h_v_cache_ptr_buf_[idx] = h_v_cache_ptr_buf_[i];
requests_[idx] = std::move(requests_[i]);
cached_seq_[idx] = std::move(cached_seq_[i]);
check_cuda_error(cudaMemcpyAsync(output_ids_buf_ + idx * session_len_,
output_ids_buf_ + i * session_len_,
sizeof(int) * h_context_length_buf_[idx],
cudaMemcpyDefault,
stream_));
std::vector<Signal> signals;
{
NvtxScope _("prepare_completion_signal");
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i] && state_->h_finished[i]) {
CompleteRequest(i, false, false);
signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
}
++idx;
}
}
batch_size_ = idx;
if (rank_ == 0) {
TM_LOG_INFO("[synchronize] batch_size = %d", (int)batch_size_);
}
finished_count_ = 0;
return signals;
}
template<typename T>
void LlamaBatch<T>::setOutputTensors(int max_gen_step)
void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
{
NvtxScope scope("SetOutputTensors");
// dbg(g.max_init_ctx_len);
const auto batch_size = state_->active_size;
// [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
invokeGatherOutput(output_ids_buf_,
invokeGatherOutput(state_->output_ids,
token_ids_buf_,
context_length_buf_,
max_context_len_,
max_gen_step,
g.max_init_ctx_len,
g.step,
session_len_,
batch_size_,
batch_size,
stream_);
sync_check_cuda_error();
/// TODO: fuse the loop into a single kernel
for (int i = 0; i < batch_size_; ++i) {
if (requests_[i]) {
auto& output_ids = requests_[i]->outputs[rank_].at("output_ids");
auto& sequence_length = requests_[i]->outputs[rank_].at("sequence_length");
check_cuda_error(cudaMemcpyAsync(output_ids.getPtr<int>(),
output_ids_buf_ + i * session_len_,
sizeof(int) * output_ids.shape.at(2),
cudaMemcpyDefault,
stream_));
check_cuda_error(cudaMemcpyAsync(
sequence_length.getPtr<int>(), sequence_lengths_ + i, sizeof(int), cudaMemcpyDefault, stream_));
if (max_gen_step > max_context_len_) { // +1 for newly generated token
invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
}
}
if constexpr (1) {
invokeUpdateOutput(request_output_ids_ptrs_,
request_seqlen_ptrs_,
state_->output_ids,
sequence_lengths_,
request_output_ids_lens_,
session_len_,
g.step > g.max_init_ctx_len,
batch_size,
stream_);
sync_check_cuda_error();
}
else {
// for (int i = 0; i < batch_size; ++i) {
// if (state_->requests[i]) {
// auto& output_ids = state_->requests[i]->outputs[rank_].at("output_ids");
// auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
// Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
// Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
// if (g.step > g.max_init_ctx_len) { // +1 for newly generated token
// invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
// }
// }
// }
}
}
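// --- Host-side sketch of what `invokeGatherOutput` does above (an assumption based on
// the surrounding comments, not the actual kernel): transpose the step-major [S, B]
// token ids into per-slot [B, S] rows, skipping the unused range
// [context_len[b], max_init_ctx_len) that padded shorter prompts.
#include <vector>

static void GatherOutput(std::vector<int>&       output_ids,   // [B, session_len]
                         const std::vector<int>& token_ids,    // [S, B] with S >= step
                         const std::vector<int>& context_len,  // [B]
                         int max_init_ctx_len,
                         int step,
                         int session_len,
                         int batch_size)
{
    for (int b = 0; b < batch_size; ++b) {
        int dst = 0;
        for (int s = 0; s < step; ++s) {
            if (s < context_len[b] || s >= max_init_ctx_len) {
                output_ids[b * session_len + dst++] = token_ids[s * batch_size + b];
            }
        }
    }
}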
template<typename T>
void LlamaBatch<T>::finishRequest(int index, bool force_end)
void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_force_end)
{
if (rank_ == 0) {
TM_LOG_INFO("[finishRequest] slot = %d, id = %lu", index, (long)requests_[index]->id);
TM_LOG_INFO("[CompleteRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
}
if (debug_ && rank_ == 0) {
std::vector<int> tokens(h_sequence_lengths_[index] + 1);
cudaMemcpyAsync(tokens.data(),
output_ids_buf_ + index * session_len_,
sizeof(int) * tokens.size(),
cudaMemcpyDefault,
stream_);
std::vector<int> tokens(state_->h_context_length[index]);
Copy(state_->output_ids + index * session_len_, tokens.size(), tokens.data());
cudaStreamSynchronize(stream_);
std::stringstream ss;
for (const auto& t : tokens) {
ss << " " << t;
}
TM_LOG_INFO("[finishRequest] slot %d, tokens [%s]", index, ss.str().c_str());
TM_LOG_INFO("[CompleteRequest] slot %d, tokens [%s]", index, ss.str().c_str());
}
auto& output_ids_tensor = requests_[index]->outputs[rank_].at("output_ids");
const auto output_ids_data = output_ids_tensor.getPtr<int>();
if (requests_[index]->end_flag || force_end) {
llama_->kv_cache_mgr_->erase(requests_[index]->id);
if (state_->requests[index]->end_flag || is_force_end) {
sequence_manager_->Erase(state_->requests[index]->id);
}
else {
// the last generated token is not processed by the decoder and thus has no k/v cache
const int n_steps = step_ - max_context_len_;
const int cache_len = h_sequence_lengths_[index];
const int output_len = n_steps > 0 ? cache_len + 1 : cache_len;
auto& seq = cached_seq_[index];
// account for the last generated token when this is not a stop request (stop requests don't generate one)
const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(is_stop_request);
seq.cache_len = cache_len;
auto& seq = *state_->sequences[index];
// update token IDs
seq.token_ids.resize(output_len);
check_cuda_error(cudaMemcpyAsync(
seq.token_ids.data(), output_ids_data, sizeof(int) * output_len, cudaMemcpyDefault, stream_));
seq.tokens.resize(output_len);
const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
Copy(output_ids_data, output_len, seq.tokens.data());
// update random states
seq.random_state_.resize(sizeof(curandState_t) * 2);
check_cuda_error(cudaMemcpyAsync(seq.random_state_.data(),
llama_->dynamic_decode_layer_->topk_curandstate_buf() + index,
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
check_cuda_error(cudaMemcpyAsync(seq.random_state_.data() + sizeof(curandState_t),
llama_->dynamic_decode_layer_->topp_curandstate_buf() + index,
sizeof(curandState_t),
cudaMemcpyDefault,
stream_));
seq.random_state.resize(sizeof(curandState_t) * 2);
// save random state in host memory
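// (layout: two curandState_t back to back, top-k sampling state first, then top-p)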
if (auto ptr = (curandState_t*)seq.random_state.data()) {
ptr = Copy(model_->GetTopKState(index), 1, ptr);
ptr = Copy(model_->GetTopPState(index), 1, ptr);
}
check_cuda_error(cudaStreamSynchronize(stream_));
llama_->kv_cache_mgr_->update(cached_seq_[index], stream_);
sequence_manager_->UpdateAndSetUnlock(seq);
}
state_->sequences[index] = nullptr;
}
template<typename T>
void LlamaBatch<T>::InternalThreadEntry(int device_id)
{
TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
check_cuda_error(cudaSetDevice(device_id));
auto& shared_state = model_->shared_state_;
auto& request_queue = shared_state->request_queue;
auto& infer_requests = shared_state->infer_requests;
auto& stop_requests = shared_state->stop_requests;
int finished_count = 0;
GenerationState g{};
while (1) {
if (rank_ == 0) {
const int free_slot_count = max_batch_size_ - state_->size + finished_count;
const bool is_empty = (free_slot_count == max_batch_size_);
// will block if batch is empty
request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);
if (!shared_state->abort) {
RejectInvalidRequests(stop_requests, infer_requests);
}
}
NvtxScope scope("mainloop");
// wait while rank-0 is dequeueing
shared_state->barrier->wait();
if (shared_state->abort) {
TM_LOG_INFO("[InternalThreadEntry] stop requested.");
// if (state_->size && rank_ == 0) {
// TM_LOG_WARNING("Active request(s) present (%d) while exiting.", state_->size);
// }
return;
}
auto signals = ProcessStopRequests(stop_requests);
BarrierSignalRequests(*shared_state->barrier, signals);
ProcessInferRequests(infer_requests);
// wait while shared stop/infer_requests is being used
shared_state->barrier->wait();
auto modified = Initialize();
ContextDecode();
if (state_->active_size) {
if (modified) {
g = InitializeGeneration();
InitializeSampling();
}
for (int i = 0; i < step_length_; ++i) {
if (!Generate(g)) {
break;
}
}
auto signals = Finish(g);
finished_count = signals.size();
BarrierSignalRequests(*shared_state->barrier, signals);
}
}
// When the signal is set, threads from LlamaV2::forward can exit
// and free the input/output tensors.
// Therefore we need to make sure that no threads from LlamaV2::internalThreadEntry
// are still accessing the tensors.
llama_->shared_state_->barrier->wait();
FT_CHECK(0);
}
template<typename T>
void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals)
{
if (!signals.empty()) {
barrier.wait();
if (rank_ == 0) {
std::for_each(signals.cbegin(), signals.cend(), [](auto& s) { s(); });
}
barrier.wait();
}
}
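// Signals are deferred closures built in `Finish`: each captures a finished request and
// completes it via `signal.set_value(0)`. Running them only on rank 0, bracketed by
// barrier waits, keeps all TP ranks in step while the requester threads are released.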
template<typename T>
void LlamaBatch<T>::Start()
{
TM_LOG_ERROR("LlamaBatch<T>::Start()");
int device_id = -1;
check_cuda_error(cudaGetDevice(&device_id));
internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
if (rank_ == 0) {
requests_[index]->signal.set_value(0);
output_thread_ = std::thread(&LlamaBatch::OutputThreadEntry, this);
}
}
requests_[index] = nullptr;
template<typename T>
void LlamaBatch<T>::OutputThreadEntry()
{
while (true) {
{
// wait for requests with stream cbs
std::unique_lock lock(output_mutex_);
output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
// NvtxScope _("output_callback");
// stop requested
if (output_stop_token_) {
TM_LOG_INFO("[OutputThreadEntry] stop requested.");
return;
}
if (rank_ == 0 && model_->ffi_lock_) {
TM_LOG_INFO("acquire GIL");
model_->ffi_lock_(1);
TM_LOG_INFO("acquire GIL success");
}
// invoke stream cbs
for (const auto& r : output_reqs_) {
r->stream_cb(&r->outputs[rank_].get());
}
if (rank_ == 0 && model_->ffi_lock_) {
TM_LOG_INFO("release GIL");
model_->ffi_lock_(0);
TM_LOG_INFO("release GIL success");
}
output_reqs_.clear();
}
FT_CHECK(output_reqs_.empty());
// notify infer thread 0
output_cv_.notify_one();
}
}
template class LlamaBatch<half>;
......
......@@ -2,66 +2,139 @@
#pragma once
#include "src/turbomind/models/llama/LlamaCacheManager.h"
// #include "src/turbomind/models/llama/LlamaCacheManager.h"
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include <condition_variable>
#include <mutex>
namespace turbomind {
struct BatchState {
int* h_context_length;
bool* h_finished;
void* top_k_curand_state;
void* top_p_curand_state;
int* output_ids; // output ids in [B, S]
float* h_rope_theta;
std::vector<int> seq_len_limit;
std::vector<int> is_swap_in;
std::vector<const Sequence*> sequences;
std::vector<std::shared_ptr<Request>> requests;
// |<-- existing -->|<-- swap-in -->|
// |<----------- active ----------->|<-- inactive -->|
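// `active_size` counts the existing + swap-in slots; `size` additionally includes the inactive ones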
int active_size;
int size;
};
template<typename T>
class LlamaV2;
template<typename T>
class LlamaBatch {
public:
int size() const noexcept
{
return batch_size_;
};
void AllocateBuffer(size_t batch_size, size_t session_len);
void AllocatePersistantBuffer(size_t max_batch_size);
void FreeBuffer();
int maxSize() const noexcept
{
return max_batch_size_;
}
using Requests = std::vector<std::shared_ptr<Request>>;
using Signal = std::function<void()>;
int finishedCount() const noexcept
{
return finished_count_;
}
void RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs);
[[nodiscard]] auto ProcessStopRequests(const Requests& requests) -> std::vector<Signal>;
void ProcessInferRequests(const Requests& requests);
void verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
std::vector<std::shared_ptr<Request>>& infer_reqs);
void handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests);
[[nodiscard]] bool Initialize();
void allocateBuffer(size_t batch_size, size_t session_len);
void allocatePersistantBuffer(size_t max_batch_size);
void freeBuffer();
void ContextDecode();
void initializeSampling(int infer_request_count);
struct GenerationState {
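// max_init_ctx_len: max context length when generation starts; `step` begins here and
//                   marks the boundary between prompt and generated tokens
// sum_seq_len / max_seq_len: running sum / max of sequence lengths over active slots,
//                            advanced once per generated token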
int max_init_ctx_len;
int step;
int sum_seq_len;
int max_seq_len;
};
void initialize(const std::vector<std::shared_ptr<Request>>& infer_requests);
void contextDecode();
void InitializeSampling();
GenerationState InitializeGeneration();
void initializeGeneration();
bool generate();
[[nodiscard]] bool Generate(GenerationState& g);
void finish();
void finishRequest(int index, bool force_end);
[[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
void synchronize();
void CompleteRequest(int index, bool is_stop_request, bool is_force_end);
void setOutputTensors(int max_gen_step);
void SetOutputTensors(const GenerationState& g);
void
outputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
explicit LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama);
explicit LlamaBatch(int max_batch_size,
int max_context_token_num,
int session_len,
std::unique_ptr<SequenceManager> sequence_manager,
LlamaV2<T>* llama);
~LlamaBatch()
{
freeBuffer();
TM_LOG_ERROR("~LlamaBatch()");
model_->shared_state_->request_queue.close();
internal_thread_.join();
if (output_thread_.joinable()) {
{
std::lock_guard lock{output_mutex_};
output_stop_token_ = true;
}
output_cv_.notify_one();
output_thread_.join();
}
FreeBuffer();
}
void Start();
private:
void InternalThreadEntry(int device_id);
void OutputThreadEntry();
void UpdateSequenceStates(BatchState& state, int index);
void CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst);
void SaveRandomState(BatchState& state, int idx);
void LoadRandomState(BatchState& state, int idx);
void BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals);
// analogous to `std::copy_n`
template<typename U>
U* Copy(const U* src, size_t count, U* dst)
{
check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(U) * count, cudaMemcpyDefault, stream_));
return dst += count;
}
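// e.g. `Copy(state_->h_context_length, batch_size, context_length_buf_)` mirrors a host
// array to the device asynchronously on `stream_`; returning `dst + count` allows chained copies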
template<typename U>
U* Clear(U* data, size_t count)
{
check_cuda_error(cudaMemsetAsync(data, 0, sizeof(U) * count, stream_));
return data += count;
}
private:
......@@ -70,52 +143,67 @@ private:
const int session_len_;
const int rank_;
const bool debug_;
const int step_length_;
LlamaV2<T>* const llama_;
// active requests
std::vector<std::shared_ptr<Request>> requests_;
T* context_decoder_input_buf_{}; // CTXDEC
T* context_decoder_output_buf_{}; // CTXDEC
int* context_decoder_ids_buf_{};
T* decoder_input_buf_{}; // CTXDEC, GENERATE
T* decoder_output_buf_{}; // CTXDEC, GENERATE
LlamaV2<T>* const model_;
int* input_ids_buf_{}; // input token ids + cache missed token ids, CTXDEC
int* input_length_buf_{}; // input + cache missed length, CTXDEC, GENERATE
int* history_length_buf_{}; // history length, CTXDEC
int* context_length_buf_{}; // history length + input_length, CTXDEC, GENERATE
std::unique_ptr<SequenceManager> sequence_manager_;
int* total_padding_count_{}; // GENERATE
int* sequence_lengths_{}; // current sequence length
///////////////////////////////////////////////////////////////////
// k/v cache block buffers
int* cu_block_counts_{};
uintptr_t* k_block_ptrs_{};
uintptr_t* v_block_ptrs_{};
uint64_t* k_cache_ptr_buf_{};
uint64_t* v_cache_ptr_buf_{};
////////////////////////////////////////////////////////////////////
// context decoding temp buffers
T* context_decoder_input_buf_{};
T* context_decoder_output_buf_{};
int* context_decoder_ids_buf_{};
int* input_ids_buf_{};
// lengths
int* input_length_buf_{}; // input + cache missed length
int* context_length_buf_{}; // history length + input_length
// temp buffers used for block->linear kv-cache conversion
T* tmp_k_cache_buf_{};
T* tmp_v_cache_buf_{};
void** tmp_k_ptrs_{};
void** tmp_v_ptrs_{};
void** h_tmp_k_ptrs_{};
void** h_tmp_v_ptrs_{};
T* decoder_input_buf_{};
T* decoder_output_buf_{};
int* sequence_lengths_{}; // current sequence length
int* init_ctx_lens_{};
float* logits_buf_{}; // combined logits
float* local_logits_buf_{}; // tensor parallel local logits
float* context_logits_buf_{};
float* local_context_logits_buf_{};
float* rope_theta_{};
// used by dynamic decoder
int* token_ids_buf_{}; // all token IDs in [S, B], indexed using `step`
int* output_ids_buf_{}; // output ids in [B, S]
int* token_ids_buf_{}; // all token IDs in [S, B], indexed using `step`
int* end_ids_buf_{};
bool* finished_buf_{};
uint32_t* seq_limit_len_{};
int** request_output_ids_ptrs_{};
int* request_output_ids_lens_{};
int** request_seqlen_ptrs_{};
int** h_request_output_ids_ptrs_{};
int* h_request_output_ids_lens_{};
int** h_request_seqlen_ptrs_{};
// pinned buffers
int* h_input_ids_buf_{};
int* h_input_length_buf_{};
int* h_history_length_buf_{};
int* h_context_length_buf_{};
int* h_sequence_lengths_{};
bool* h_finished_buf_{};
uintptr_t* h_k_cache_ptr_buf_{};
uintptr_t* h_v_cache_ptr_buf_{};
uint32_t* h_seq_limit_len_{};
int* h_cu_block_counts_{};
uintptr_t* h_k_block_ptrs_{};
uintptr_t* h_v_block_ptrs_{};
int* stop_words_buf_{}; // [batch_size, 2, kMaxStopWordsLen]
int* bad_words_buf_{};
......@@ -125,24 +213,19 @@ private:
float* h_repetition_penalty_{};
uint64_t* h_random_seed_{};
void* topk_curandstate_buf_{};
void* topp_curandstate_buf_{};
std::array<BatchState, 3> states_{};
// hard limits for persistent buffers
static constexpr int kMaxStopBadWordsLen = 32;
BatchState* state_{};
BatchState* back_{};
BatchState* incoming_{};
using CachedSeq = LlamaCacheManager::Sequence;
uint64_t request_count_{0};
std::vector<CachedSeq> cached_seq_;
std::vector<int> request_seq_len_limit_;
// hard limits for persistent buffers
static constexpr int kMaxStopBadWordsLen = 32;
const DataType data_type_{};
int batch_size_{};
int max_context_len_{};
int step_{};
int finished_count_{};
bool is_allocate_persistant_buffer_ = false;
bool is_allocate_buffer_ = false;
......@@ -154,6 +237,15 @@ private:
cudaStream_t stream_{};
cublasMMWrapper* cublas_wrapper_{};
IAllocator* allocator_{};
std::thread internal_thread_;
// async stream callback utils
std::thread output_thread_;
std::mutex output_mutex_;
std::condition_variable output_cv_;
Requests output_reqs_;
bool output_stop_token_{false};
};
} // namespace turbomind
......@@ -21,6 +21,7 @@
#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
#include "src/turbomind/kernels/bert_preprocess_kernels.h"
#include "src/turbomind/kernels/decoder_multihead_attention/kv_cache.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
......@@ -28,6 +29,7 @@
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"
namespace turbomind {
......@@ -116,6 +118,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
* \param history_lengths [batch_size], int
* \param context_lengths [batch_size], int
* \param cu_seqlens [batch_size+1], int
* \param cu_block_counts [batch_size+1], int
* \param max_seq_len [1], int on cpu
* \param is_final_layer [1], bool on cpu
* \param layer_id [1], int on cpu
......@@ -141,13 +144,23 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
T* attention_input = input_tensors->at("input_query").getPtr<T>();
T* attention_mask = input_tensors->at("attention_mask").getPtr<T>();
const auto input_length = input_tensors->at("input_lengths").getPtr<const int>();
const auto history_length = input_tensors->at("history_lengths").getPtr<const int>();
const auto context_length = input_tensors->at("context_lengths").getPtr<const int>();
int* cu_seqlens = input_tensors->at("cu_seqlens").getPtr<int>();
const auto input_length = input_tensors->at("input_lengths").getPtr<const int>();
const auto context_length = input_tensors->at("context_lengths").getPtr<const int>();
int* cu_seqlens = input_tensors->at("cu_seqlens").getPtr<int>();
int* cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
const float* rope_theta = input_tensors->getPtr<const float>("rope_theta", nullptr);
const auto padding_offset = input_tensors->at("padding_offset").getPtr<int>();
auto Show = [&](const T* x, size_t n) {
std::vector<T> vec(n);
cudaMemcpyAsync(vec.data(), x, sizeof(T) * n, cudaMemcpyDefault, stream_);
cudaStreamSynchronize(stream_);
std::vector<float> float_vec(vec.begin(), vec.end());
dbg(float_vec);
};
/////////////////////////////////////////////
/// allocate buffers
allocateBuffer(batch_size, num_token, max_q_len, max_k_len);
......@@ -166,26 +179,32 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
qkv_buf_,
weights->qkv.bias,
padding_offset, // padding_offset,
history_length, // used for applying rotary embedding
context_length, // used for applying rotary embedding
input_length,
rope_theta,
batch_size,
max_q_len, // seq_len
num_token, // batch_size * seq_len
local_head_num_,
local_kv_head_num_,
size_per_head_,
params_.rotray_embedding_dim,
params_.rotary_embedding_dim,
params_.rotary_embedding_base,
params_.max_position_embeddings,
params_.use_dynamic_ntk,
false, // params_.use_dynamic_ntk,
params_.use_logn_attn,
stream_);
sync_check_cuda_error();
const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
// [2, L, H, s, D]
const size_t layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
auto k_cache_ptrs = output_tensors->getPtr<void*>("key_cache");
auto v_cache_ptrs = output_tensors->getPtr<void*>("value_cache");
auto tmp_k_ptrs = output_tensors->getPtr<T*>("tmp_k");
auto tmp_v_ptrs = output_tensors->getPtr<T*>("tmp_v");
auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
//////////////////////////////////////////////////////////
/// insert the k/v computed from inputs into k/v cache
/// transpose kv -> kv cache
......@@ -194,25 +213,53 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
// v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x]
invokeExtendKVCache(k_cache_ptrs,
v_cache_ptrs,
layer_offset,
k_buf_2_,
v_buf_2_,
batch_size,
cu_block_counts,
input_length,
context_length,
batch_size,
kv_cache_block_len_,
layer_offset,
max_q_len,
history_length,
max_seq_len,
size_per_head_,
local_kv_head_num_,
stream_,
quant_policy_,
weights->past_kv_scale.data());
weights->past_kv_scale.data(),
stream_);
sync_check_cuda_error();
const int kv_cache_elem_bits = quant_policy_ & QuantPolicy::kCacheKVInt8 ? 8 : sizeof(T) * 8;
ConvertKvCacheBlocksToLinear2((const void**)k_cache_ptrs,
(const void**)v_cache_ptrs,
(T**)tmp_k_ptrs,
(T**)tmp_v_ptrs,
cu_block_counts,
context_length,
layer_offset,
kv_cache_block_len_,
max_seq_len,
local_kv_head_num_,
size_per_head_,
batch_size,
quant_policy_,
weights->past_kv_scale.data(),
stream_);
sync_check_cuda_error();
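// The conversion above gathers this layer's K/V from the paged cache blocks (each block
// holding `kv_cache_block_len_` tokens at `layer_offset`) into the contiguous per-sequence
// `tmp_k`/`tmp_v` buffers with a uniform `max_seq_len` stride, presumably dequantizing to T
// when the int8 cache policy is enabled (cf. the "dequant handled in block->linear
// conversion" note below), so the attention below can run on linear memory.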
// dbg(kv_cache_block_len_, max_seq_len, local_kv_head_num_, size_per_head_, batch_size);
// void *kk, *vv;
// cudaMemcpyAsync(&kk, tmp_k_ptrs, sizeof(void*), cudaMemcpyDefault, stream_);
// cudaMemcpyAsync(&vv, tmp_v_ptrs, sizeof(void*), cudaMemcpyDefault, stream_);
// cudaStreamSynchronize(stream_);
// Show((const T*)kk, local_kv_head_num_ * max_seq_len * size_per_head_);
// Show((const T*)vv, local_kv_head_num_ * max_seq_len * size_per_head_);
if (use_fmha_) {
fusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs,
layer_offset,
fusedMultiHeadAttention(tmp_k_ptrs,
tmp_v_ptrs,
0,
attention_mask,
cu_seqlens,
input_tensors->at("context_lengths").getPtr<int>(),
......@@ -222,9 +269,9 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
max_seq_len);
}
else {
unfusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs,
layer_offset,
unfusedMultiHeadAttention(tmp_k_ptrs,
tmp_v_ptrs,
0,
attention_mask,
padding_offset,
context_length,
......@@ -237,6 +284,14 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
weights->past_kv_scale.data());
}
// Compare(qkv_buf_3_, num_token * hidden_units_, Concat("qkv_buf_3", layer_id), kCmpRead, stream_);
// dbg(max_seq_len);
if (0) {
Show(qkv_buf_3_, num_token * hidden_units_);
}
//////////////////////////////////////////////
/// output gemm <Bs,HD> -> <Bs,HD>
linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output);
......@@ -342,7 +397,7 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c
local_head_num_,
head_n_rep_,
stream_,
quant,
0, // dequant handled in block->linear conversion
kv_scale);
sync_check_cuda_error();
......
......@@ -45,6 +45,7 @@ public:
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool use_fmha,
int cache_block_seq_len,
int quant_policy):
head_num_(head_num),
size_per_head_(size_per_head),
......@@ -58,6 +59,7 @@ public:
cublas_wrapper_(cublas_wrapper),
linear_(cublas_wrapper, stream),
allocator_(allocator),
kv_cache_block_len_(cache_block_seq_len),
is_free_buffer_after_forward_(is_free_buffer_after_forward),
use_fmha_(use_fmha),
quant_policy_(quant_policy)
......@@ -99,6 +101,7 @@ private:
const size_t local_kv_head_num_;
const size_t local_head_num_;
const size_t head_n_rep_;
const size_t kv_cache_block_len_;
const bool is_free_buffer_after_forward_;
const LlamaAttentionParams params_;
......
......@@ -25,7 +25,9 @@
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/debug_utils.h"
namespace turbomind {
......@@ -64,6 +66,7 @@ template<typename T>
void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
bool use_fmha,
int cache_block_seq_len,
int quant_policy)
{
h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
......@@ -78,6 +81,7 @@ void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
allocator_,
is_free_buffer_after_forward_,
use_fmha,
cache_block_seq_len,
quant_policy);
silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
......@@ -93,6 +97,7 @@ void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
template<typename T>
void LlamaContextDecoder<T>::forwardSelfAttn(const Session& sess,
T* attn_io,
std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
int layer,
bool is_final)
......@@ -107,18 +112,17 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
{"padding_offset", {MEMORY_GPU, TYPE_INT32, {sess.token_num}, padding_offset_}},
{"cu_seqlens", {MEMORY_GPU, TYPE_INT32, {sess.batch_size + 1}, cu_seqlens_}},
{"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
{"history_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.history_length}},
{"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
{"cu_block_counts", input_tensors->at("cu_block_counts")},
{"rope_theta", input_tensors->at("rope_theta")},
{"max_seq_len", input_tensors->at("max_seq_len")}};
auto& k_cache = *sess.k_cache;
auto& v_cache = *sess.v_cache;
TensorMap self_attention_output_tensors{
{"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
{"key_cache", k_cache},
{"value_cache", v_cache},
};
{"key_cache", output_tensors->at("key_cache")},
{"value_cache", output_tensors->at("value_cache")},
{"tmp_k", output_tensors->at("tmp_k")},
{"tmp_v", output_tensors->at("tmp_v")}};
context_attention_layer_->forward(&self_attention_output_tensors, //
&self_attention_input_tensors,
......@@ -139,6 +143,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool use_fmha,
int cache_block_seq_len,
int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num),
......@@ -150,7 +155,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num
tensor_para_(tensor_para),
data_type_(getTensorType<T>())
{
initialize(attn_params, kv_head_num, use_fmha, quant_policy);
initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
}
template<typename T>
......@@ -201,17 +206,16 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
sess.weights = decoder_layer_weights;
sess.input_length = input_tensors->at("input_lengths").getPtr<int>();
sess.history_length = input_tensors->at("history_lengths").getPtr<int>();
sess.context_length = input_tensors->at("context_lengths").getPtr<int>();
T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
sess.k_cache = &output_tensors->at("key_cache");
sess.v_cache = &output_tensors->at("value_cache");
allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len);
// dbg(padding_offset_);
FT_CHECK(padding_offset_);
size_t tmp_token_num{};
invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_,
&tmp_token_num, // updated token num
......@@ -222,6 +226,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
sess.max_query_len,
stream_);
sync_check_cuda_error();
dbg(tmp_token_num, sess.token_num);
FT_CHECK(tmp_token_num == sess.token_num);
invokeCreateCausalMasks(attention_mask_,
......@@ -233,6 +238,9 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
stream_);
sync_check_cuda_error();
// Compare(
// decoder_input_output, sess.token_num * hidden_units_, Concat("context_decoder_input", 0), kCmpRead, stream_);
/////////////////////////////////////////////
/// RMSNorm
invokeRootMeanSquareNorm(decoder_output,
......@@ -247,7 +255,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
for (size_t layer = 0; layer < num_layer_; ++layer) {
/////////////////////////////////////////////
/// self-attention
forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);
forwardSelfAttn(sess, decoder_output, output_tensors, input_tensors, layer, false);
invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
decoder_output,
......
......@@ -40,7 +40,11 @@ protected:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
void freeBuffer() override;
void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_fmha, int quant_policy);
void initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
bool use_fmha,
int cache_block_seq_len,
int quant_policy);
size_t head_num_;
size_t size_per_head_;
......@@ -63,21 +67,19 @@ protected:
const DataType data_type_;
struct Session {
size_t batch_size;
size_t token_num;
size_t max_query_len;
size_t max_key_len;
Tensor* k_cache;
Tensor* v_cache;
int* input_length{};
int* history_length{};
int* context_length{};
size_t batch_size;
size_t token_num;
size_t max_query_len;
size_t max_key_len;
int* input_length{};
int* context_length{};
const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
};
void forwardSelfAttn(const Session& sess,
T* attn_io,
std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
int layer,
bool is_final);
......@@ -96,6 +98,7 @@ public:
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool use_fmha,
int cache_block_seq_len,
int quant_policy);
~LlamaContextDecoder() override;
......
......@@ -41,6 +41,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int cache_block_seq_len,
int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num),
......@@ -53,7 +54,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num,
data_type_(getTensorType<T>())
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
initialize(attn_params, kv_head_num, quant_policy);
initialize(attn_params, kv_head_num, cache_block_seq_len, quant_policy);
}
template<typename T>
......@@ -65,7 +66,10 @@ LlamaDecoder<T>::~LlamaDecoder()
}
template<typename T>
void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy)
void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
int cache_block_seq_len,
int quant_policy)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
......@@ -78,6 +82,7 @@ void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
cache_block_seq_len,
quant_policy);
silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
......@@ -118,6 +123,7 @@ void LlamaDecoder<T>::forwardSelfAttn(const LlamaDecoder::Session&
const std::unordered_map<std::string, Tensor>* input_tensors,
size_t layer)
{
NvtxScope scope("self_attn");
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TensorMap self_attention_input_tensors(*input_tensors);
self_attention_input_tensors.insert("input_query",
......@@ -180,60 +186,73 @@ void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* ou
// for the shape of key cache, refer to decoder_masked_multihead_attention_template.hpp
NvtxScope forward_scope("decoder_forward");
Session sess{};
sess.batch_size = input_tensors->at("decoder_input").shape[0];
sess.weights = decoder_layer_weights;
allocateBuffer(sess.batch_size);
sess.ite = input_tensors->at("ite").getVal<const int>();
sess.k_cache = &output_tensors->at("key_cache");
sess.v_cache = &output_tensors->at("value_cache");
sess.max_memory_len = input_tensors->at("max_seq_len").getVal<int>();
T* decoder_input = input_tensors->at("decoder_input").getPtr<T>();
T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
int step = input_tensors->at("step").getVal<int>();
// Compare(decoder_input, sess.batch_size * hidden_units_, Concat("decoder_input", 0, step), kCmpRead, stream_);
////////////////////////////////////////////
/// RMSNorm
invokeRootMeanSquareNorm(decoder_output,
decoder_input,
decoder_layer_weights->at(0)->self_attn_norm_weights,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
{
NvtxScope rms_norm_scope("rms_norm_0");
invokeRootMeanSquareNorm(decoder_output,
decoder_input,
decoder_layer_weights->at(0)->self_attn_norm_weights,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
}
for (size_t layer = 0; layer < num_layer_; ++layer) {
NvtxScope layer_scope("decode_layer");
// output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
forwardSelfAttn(sess, decoder_output, input_tensors, layer);
invokeFusedAddBiasResidualRMSNorm(decoder_input,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
decoder_layer_weights->at(layer)->ffn_norm_weights,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
{
NvtxScope rms_norm_scope("rms_norm_1");
invokeFusedAddBiasResidualRMSNorm(decoder_input,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
decoder_layer_weights->at(layer)->ffn_norm_weights,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
}
// decoder_layer_output_ = ffn(decoder_normed_input_)
forwardFfn(sess, decoder_output, layer);
auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input, //
decoder_output,
decoder_layer_weights->at(layer)->ffn_weights.output.bias,
scale_weight,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
{
NvtxScope rms_norm_scope("rms_norm_2");
auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input, //
decoder_output,
decoder_layer_weights->at(layer)->ffn_weights.output.bias,
scale_weight,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
}
}
if (is_free_buffer_after_forward_) {
......
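The loop above is a pre-norm decoder layer: RMSNorm feeds self-attention, then a fused kernel adds the residual (plus the attention output bias) and applies the next RMSNorm, and the same fused add-and-norm follows the FFN, with `decoder_input` carrying the running residual. A minimal host-side sketch of that dataflow, assuming hypothetical buffer names and omitting the bias term and all GPU specifics:

#include <cmath>
#include <vector>

// Reference RMSNorm: y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]
static void rms_norm(const std::vector<float>& x, const std::vector<float>& w, float eps, std::vector<float>& y)
{
    float sum_sq = 0.f;
    for (float v : x) {
        sum_sq += v * v;
    }
    const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + eps);
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] * inv_rms * w[i];
    }
}

// Fused "add residual, then RMSNorm" pattern used between the sub-layers above:
// residual += sublayer_out; normed = rms_norm(residual)
static void add_residual_rms_norm(std::vector<float>&       residual,
                                  const std::vector<float>& sublayer_out,
                                  const std::vector<float>& w,
                                  float                     eps,
                                  std::vector<float>&       normed)
{
    for (size_t i = 0; i < residual.size(); ++i) {
        residual[i] += sublayer_out[i];
    }
    rms_norm(residual, w, eps, normed);
}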
......@@ -35,7 +35,8 @@ protected:
void allocateBuffer() override; // deprecated
void allocateBuffer(size_t batch_size);
void freeBuffer() override;
void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy);
void
initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int cache_block_seq_len, int quant_policy);
size_t head_num_;
size_t size_per_head_;
......@@ -53,8 +54,6 @@ protected:
struct Session {
size_t batch_size;
int ite;
size_t max_memory_len;
Tensor* k_cache;
Tensor* v_cache;
const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
......@@ -80,6 +79,7 @@ public:
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int cache_block_seq_len,
int quant_policy);
~LlamaDecoder() override;
......
......@@ -302,7 +302,7 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
self_attn_weights.past_kv_scale = loadArrayFromBin({4}, scale_path);
}
else {
self_attn_weights.past_kv_scale = {};
self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f};
}
}
......
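When no scale file is found, `past_kv_scale` now defaults to `{1.f, 0.f, 1.f, 0.f}` rather than an empty array, i.e. (scale, zero-point) pairs of (1, 0) for K and V in the order the attention code reads them (k_scale, k_zp, v_scale, v_zp). A hedged sketch of how such a pair is typically applied to an int8 KV cache; this illustrates the convention only, not the kernel's exact code:

#include <cstdint>

// Dequantize a cached int8 value with a (scale, zero_point) pair; with
// scale = 1.f and zero_point = 0.f this reduces to a plain cast.
inline float dequant_kv(int8_t q, float scale, float zero_point)
{
    return static_cast<float>(q) * scale + zero_point;
}

// The matching quantization direction (clamped to the int8 range).
inline int8_t quant_kv(float x, float scale, float zero_point)
{
    float q = (x - zero_point) / scale;
    q = q < -128.f ? -128.f : (q > 127.f ? 127.f : q);
    return static_cast<int8_t>(q);
}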
......@@ -19,6 +19,7 @@
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.cc
#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
#include "src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/llama_kernels.h"
......@@ -32,141 +33,7 @@
namespace turbomind {
template<typename T>
struct SATypeConverter {
using Type = T;
};
template<>
struct SATypeConverter<half> {
using Type = uint16_t;
};
template<typename T>
static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
const T* qkv_bias,
const T* relative_attention_bias,
T* key_cache,
T* value_cache,
T** k_cache_per_sample,
T** v_cache_per_sample,
size_t kv_cache_per_sample_offset,
const int* cache_indir,
T* context_buf,
const bool* finished,
const int* sequence_lengths,
const int max_batch_size,
const int inference_batch_size,
const int beam_width,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const float rotary_embedding_base,
const int max_position_embeddings,
const bool use_dynamic_ntk,
const bool use_logn_attn,
const int memory_max_len,
const int* prefix_prompt_lengths,
const int max_prefix_prompt_length,
const int max_input_len,
const int* total_padding_tokens,
const int step,
const float q_scaling,
const int relative_attention_bias_stride,
const T* linear_bias_slopes,
const bool* masked_tokens,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const float* qkv_scale_out,
const float* attention_out_scale,
const int int8_mode,
const float* attention_kv_scale,
cudaStream_t stream)
{
using DataType = typename SATypeConverter<T>::Type;
// Prepare the parameters.
Masked_multihead_attention_params<DataType> params;
memset(&params, 0, sizeof(params));
// int hidden_units = head_num * size_per_head;
if (qkv_bias != nullptr) {
params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + head_num * size_per_head;
params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + (head_num + kv_head_num) * size_per_head;
}
else {
params.q_bias = nullptr;
params.k_bias = nullptr;
params.v_bias = nullptr;
}
// Set the output buffer.
params.out = reinterpret_cast<DataType*>(context_buf);
// Set the input buffers.
// [B, nH + kvH, D]
params.q = reinterpret_cast<const DataType*>(qkv_buf);
params.k = reinterpret_cast<const DataType*>(qkv_buf) + head_num * size_per_head;
params.v = reinterpret_cast<const DataType*>(qkv_buf) + (head_num + kv_head_num) * size_per_head;
params.stride = (head_num + 2 * kv_head_num) * size_per_head;
params.finished = const_cast<bool*>(finished);
FT_CHECK(k_cache_per_sample && v_cache_per_sample);
params.k_cache_per_sample = reinterpret_cast<DataType**>(k_cache_per_sample);
params.v_cache_per_sample = reinterpret_cast<DataType**>(v_cache_per_sample);
params.kv_cache_per_sample_offset = kv_cache_per_sample_offset;
params.batch_size = inference_batch_size;
params.beam_width = beam_width;
params.memory_max_len = memory_max_len;
params.prefix_prompt_lengths = prefix_prompt_lengths;
params.max_prefix_prompt_length = max_prefix_prompt_length;
params.length_per_sample = sequence_lengths; // max_input_length + current output length
// timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
params.timestep = step + max_prefix_prompt_length - 1;
params.num_heads = head_num;
params.num_kv_heads = kv_head_num;
params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim;
params.rotary_embedding_base = rotary_embedding_base;
params.max_position_embeddings = max_position_embeddings;
params.use_dynamic_ntk = use_dynamic_ntk;
params.use_logn_attn = use_logn_attn;
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling);
params.total_padding_tokens = total_padding_tokens;
if (relative_attention_bias != nullptr) {
params.relative_attention_bias = reinterpret_cast<const DataType*>(relative_attention_bias);
}
params.relative_attention_bias_stride = relative_attention_bias_stride;
params.masked_tokens = masked_tokens;
// The slope of linear position bias per head, e.g., ALiBi.
if (linear_bias_slopes != nullptr) {
params.linear_bias_slopes = reinterpret_cast<const DataType*>(linear_bias_slopes);
}
params.max_input_length = max_input_len;
params.int8_mode = int8_mode;
if (int8_mode & QuantPolicy::kCacheKVInt8) {
params.attention_k_scale = attention_kv_scale[0];
params.attention_k_zp = attention_kv_scale[1];
params.attention_v_scale = attention_kv_scale[2];
params.attention_v_zp = attention_kv_scale[3];
}
PUSH_RANGE("scaled dot-product fusion");
masked_multihead_attention(params, stream);
POP_RANGE;
}
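Both the removed dispatcher above and its block-cache replacement below rely on the fused QKV layout produced by the QKV GEMM: for each token, Q (head_num heads) followed by K and V (kv_head_num heads each) are contiguous, so the per-token stride is (head_num + 2 * kv_head_num) * size_per_head. A small sketch of deriving the three views from one fused buffer, with hypothetical helper names:

// Per-token layout of the fused QKV buffer: [ Q(nH*D) | K(kvH*D) | V(kvH*D) ]
template<class T>
struct QkvView {
    const T* q;
    const T* k;
    const T* v;
    int      stride;  // elements from one token's Q to the next token's Q
};

template<class T>
QkvView<T> make_qkv_view(const T* qkv, int head_num, int kv_head_num, int size_per_head)
{
    QkvView<T> view{};
    view.q      = qkv;
    view.k      = qkv + head_num * size_per_head;
    view.v      = qkv + (head_num + kv_head_num) * size_per_head;
    view.stride = (head_num + 2 * kv_head_num) * size_per_head;
    return view;
}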
template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len)
void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
......@@ -177,6 +44,9 @@ void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int ke
context_buf_ =
reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
workspace_ = (float*)allocator_->reMalloc(
workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2));
is_allocate_buffer_ = true;
}
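The new `workspace_` holds the split-k partial results that `forward()` later exposes as `partial_O`, `partial_M` and `partial_L`: per (sequence, head, split) there are `size_per_head` floats of partial output plus one running max and one softmax denominator, hence the `(size_per_head_ + 2)` factor. A hypothetical helper mirroring the pointer arithmetic used in `forward()`:

#include <cstddef>

// Partition a float workspace of batch * heads * max_split_k * (size_per_head + 2)
// elements into the three split-k buffers consumed by the attention kernel.
struct SplitKWorkspace {
    float* partial_O;  // [batch, heads, max_split_k, size_per_head]
    float* partial_M;  // [batch, heads, max_split_k] running max
    float* partial_L;  // [batch, heads, max_split_k] softmax denominator
};

inline SplitKWorkspace
partition_workspace(float* base, size_t batch, size_t heads, size_t max_split_k, size_t size_per_head)
{
    SplitKWorkspace w{};
    w.partial_O = base;
    w.partial_M = w.partial_O + batch * heads * max_split_k * size_per_head;
    w.partial_L = w.partial_M + batch * heads * max_split_k;
    return w;
}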
......@@ -215,79 +85,135 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
const T* input_query_data = input_tensors->getPtr<T>("input_query");
const int* sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
const int* total_padding_len = input_tensors->getPtr<int>("total_padding_tokens");
const bool* finished_data = input_tensors->getPtr<bool>("finished", nullptr);
const bool* masked_tokens_data = input_tensors->getPtr<bool>("masked_tokens", nullptr);
const int* cache_indir = input_tensors->getPtr<int>("cache_indirection", nullptr);
const bool* finished_data = input_tensors->getPtr<bool>("finished");
const int sum_seq_len = input_tensors->getVal<int>("sum_seq_len");
const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
T* hidden_features_data = output_tensors->getPtr<T>("attention_output");
T** key_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
T** value_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
const int layer_id = input_tensors->getVal<int>("layer_id");
int* cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
const int step = input_tensors->getVal<int>("step");
const int layer_id = input_tensors->getVal<int>("layer_id");
const int step_1 = step - 1;
const int step = input_tensors->getVal<int>("step");
// const int step_1 = step - 1;
const int batch_size = input_tensors->at("input_query").shape[0];
const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1;
allocateBuffer(batch_size, step, max_seq_len);
PUSH_RANGE("qkv_gemm");
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
POP_RANGE;
const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
const int memory_len = max_seq_len;
fusedQKV_masked_attention_dispatch<T>(
qkv_buf_,
weights->qkv.bias, // query_weight.bias,
nullptr, // relative_attention_bias,
nullptr,
nullptr,
key_cache_ptrs,
value_cache_ptrs,
kv_cache_layer_offset,
cache_indir,
context_buf_,
finished_data,
sequence_lengths_data, // NOTE: current seq len including padding (fixed after meeting the finished id)
batch_size,
batch_size,
beam_width,
local_head_num_,
local_kv_head_num_,
size_per_head_,
params_.rotray_embedding_dim,
params_.rotary_embedding_base,
params_.max_position_embeddings,
params_.use_dynamic_ntk,
params_.use_logn_attn,
memory_len,
nullptr, // prefix_prompt_lengths
0, // max_prefix_prompt_length
0, // max_input_length, not used w/o linear_bias_slopes
input_tensors->getPtr<int>("total_padding_tokens", nullptr),
step,
1.f, // q_scaling
0, // relative_attention_bias_stride
nullptr, // linear_bias_slopes
nullptr, // masked_tokens_data,
nullptr, // ia3_tasks
nullptr, // ia3_key_weights
nullptr, // ia3_value_weights
nullptr, // qkv_scale_out
nullptr, // attention_out_scale
quant_policy_, // int8_mode
weights->past_kv_scale.data(), // attention kv scale
stream_);
sync_check_cuda_error();
linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
const float* rope_theta = input_tensors->getPtr<const float>("rope_theta", nullptr);
allocateBuffer(batch_size);
// for (int i = 0; i < batch_size; ++i) {
// if (gSequenceIds(i) == 1) {
// Compare((T*)input_query_data + hidden_units_ * i,
// hidden_units_,
// Concat("query", gSequenceIds(i), seqlens[i], layer_id),
// compare_mode,
// stream_);
// }
// }
{
NvtxScope scope("qkv_gemm");
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
}
// if (layer_id == 0) {
// Compare(qkv_buf_, batch_size * 3 * hidden_units_, Concat("qkv_buf", step, layer_id), kCmpRead, stream_);
// }
const auto layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
// const int memory_len = max_seq_len;
DecoderMultiHeadAttentionParams<T> params{};
params.out = context_buf_;
params.q = qkv_buf_;
params.k = params.q + local_head_num_ * size_per_head_;
params.v = params.k + local_kv_head_num_ * size_per_head_;
params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
params.q_bias = weights->qkv.bias;
params.k_bias = params.q_bias + local_head_num_ * size_per_head_;
params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_;
params.batch_size = batch_size;
params.cu_block_cnts = cu_block_counts;
params.k_cache_block_ptrs = (void**)key_cache_ptrs;
params.v_cache_block_ptrs = (void**)value_cache_ptrs;
params.kv_cache_block_size = kv_cache_block_len_;
params.finished = finished_data;
params.per_sample_length = sequence_lengths_data;
params.rope_theta = rope_theta;
params.layer_offset = layer_offset;
params.num_heads = local_head_num_;
params.num_kv_heads = local_kv_head_num_;
params.size_per_head = size_per_head_;
params.inv_sqrt_dh = 1.f / std::sqrt((float)params.size_per_head);
params.rotary_embedding_dim = size_per_head_;
params.rotary_embedding_base = params_.rotary_embedding_base;
params.max_position_embeddings = params_.max_position_embeddings;
// params.use_dynamic_ntk = params_.use_dynamic_ntk;
params.use_logn_attn = params_.use_logn_attn;
params.partial_O = workspace_;
params.partial_M = params.partial_O + batch_size * local_head_num_ * kMaxSplitK * size_per_head_;
params.partial_L = params.partial_M + batch_size * local_head_num_ * kMaxSplitK;
// avg_batch_size = sum_seq_len / max_seq_len
// max_split_k = kMaxSplitK / avg_batch_size
// max_split_k' = min(max_split_k, max_seq_lens / kSliceLen)
const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
FT_CHECK(avg_batch_size >= 1.f);
const int max_split_k = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
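// Worked example with hypothetical numbers: kMaxSplitK = 16, sum_seq_len = 4096,
// max_seq_len = 1024  =>  avg_batch_size = 4096 / 1024 = 4, i.e. the batch keeps
// roughly 4 "full-length" sequences busy, so each head can afford
// max_split_k = ceil(16 / 4) = 4 split-k partitions before the extra reduction
// work outweighs the gain.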
// if (layer_id == 0) {
// TM_LOG_INFO("avg_batch_size = %.1f, max_split_k = %d", avg_batch_size, max_split_k);
// }
params.max_split_k = max_split_k;
params.max_seq_len = max_seq_len;
params.arch = arch_;
params.stream = stream_;
params.quant_policy = quant_policy_;
std::copy(weights->past_kv_scale.begin(), weights->past_kv_scale.end(), std::begin(params.kv_quant_params));
{
NvtxScope scope("decoder_multihead_attention");
DispatchDecoderMultiheadAttention<T>(params);
}
// for (int i = 0; i < batch_size; ++i) {
// if (gSequenceIds(i) == 1) {
// Compare((T*)context_buf_ + hidden_units_ * i,
// hidden_units_,
// Concat("context_buf", gSequenceIds(i), seqlens[i], layer_id),
// compare_mode,
// stream_);
// }
// }
// if (layer_id == 0) {
// Compare(context_buf_, batch_size * hidden_units_, Concat("context_buf", step, layer_id), kCmpRead, stream_);
// }
{
NvtxScope scope("o_gemm");
linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
}
if (tensor_para_.world_size_ > 1) {
NcclGuard nccl_guard(tensor_para_, stream_);
......
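The attention call above no longer walks a contiguous per-sample cache; it receives flat arrays of block pointers plus `cu_block_counts` (a prefix sum of blocks per sequence), and `layer_offset` selects the current layer's slice inside each block. A hedged sketch of how a time step maps into this paged layout; the in-block ordering shown here ([kv_head, block_seq_len, size_per_head]) is an assumption for illustration only:

// Locate the K vector of `head` at time step `t` for sequence `seq` in a
// block-based KV cache (illustrative layout, hypothetical names).
template<class T>
T* locate_k(void**     k_block_ptrs,   // flattened block pointers of all sequences
            const int* cu_block_cnts,  // prefix sum of per-sequence block counts
            int        seq,
            int        t,
            int        head,
            size_t     layer_offset,   // layer_id * kv_head_num * block_seq_len * size_per_head
            int        block_seq_len,
            int        size_per_head)
{
    const int block_id  = cu_block_cnts[seq] + t / block_seq_len;  // which block holds step t
    const int block_pos = t % block_seq_len;                       // position inside that block
    T*        block     = reinterpret_cast<T*>(k_block_ptrs[block_id]) + layer_offset;
    return block + (size_t(head) * block_seq_len + block_pos) * size_per_head;
}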
......@@ -24,6 +24,7 @@
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/nccl_utils.h"
namespace turbomind {
......@@ -32,7 +33,7 @@ template<typename T>
class LlamaDecoderSelfAttentionLayer {
public:
void freeBuffer();
void allocateBuffer(size_t batch_size, int key_len, int max_memory_len);
void allocateBuffer(size_t batch_size);
LlamaDecoderSelfAttentionLayer(size_t head_num,
size_t kv_head_num,
......@@ -43,6 +44,7 @@ public:
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int cache_block_seq_len,
int quant_policy):
head_num_(head_num),
kv_head_num_(kv_head_num),
......@@ -56,9 +58,11 @@ public:
stream_(stream),
linear_(cublas_wrapper, stream),
allocator_(allocator),
kv_cache_block_len_(cache_block_seq_len),
is_free_buffer_after_forward_(is_free_buffer_after_forward),
quant_policy_(quant_policy)
{
arch_ = getSMVersion();
}
~LlamaDecoderSelfAttentionLayer()
......@@ -76,6 +80,7 @@ private:
const size_t local_head_num_;
const size_t local_kv_head_num_;
const size_t local_hidden_units_;
const size_t kv_cache_block_len_;
const bool is_free_buffer_after_forward_;
const int quant_policy_;
......@@ -90,7 +95,11 @@ private:
T* qkv_buf_ = nullptr;
T* context_buf_ = nullptr;
static constexpr int kMaxSplitK = 16; // must be <= WARP_SIZE
float* workspace_ = nullptr;
bool is_allocate_buffer_{};
int arch_{};
};
} // namespace turbomind
......@@ -20,6 +20,7 @@
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/kernels/activation_kernels.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/nvtx_utils.h"
// #include <glog/logging.h>
......@@ -46,6 +47,7 @@ void LlamaFfnLayer<T>::freeBuffer()
template<typename T>
void LlamaFfnLayer<T>::activation(int num_token)
{
NvtxScope scope("activation");
invokeGenericActivation<SiluActivation>(gating_buf_,
(const T*)nullptr, // bias
inter_buf_,
......@@ -76,6 +78,8 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
* \param ffn_output [token_num, hidden_dimension]
*/
NvtxScope scope("ffn");
const size_t num_token = input_tensors->at("ffn_input").shape[0];
// LOG(WARNING);
......@@ -84,24 +88,28 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
const T* ffn_input_data = input_tensors->at("ffn_input").getPtr<T>();
T* ffn_output_data = output_tensors->at("ffn_output").getPtr<T>();
PUSH_RANGE("ffn");
if (weights->fused_gating_intermediate.kernel) {
NvtxScope scope("fused_silu_ffn");
linear_.forward(
gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
}
else {
// w1(x)
linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
// w3(x)
linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
{ // w1(x)
NvtxScope scope("w1");
linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
}
{ // w3(x)
NvtxScope scope("w3");
linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
}
// silu(w1(x)) * w3(x)
activation(num_token);
}
// w2(x)
linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
POP_RANGE;
{ // w2(x)
NvtxScope scope("w2");
linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
}
if (tensor_para_.world_size_ > 1) {
NcclGuard nccl_guard(tensor_para_, stream_);
......
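The branch above computes the SwiGLU feed-forward, out = w2( silu(w1(x)) * w3(x) ), with w1 and w3 optionally fused into a single GEMM when `fused_gating_intermediate` is present. A minimal CPU reference of the activation step between the GEMMs, assuming hypothetical buffers for a single token:

#include <cmath>
#include <vector>

// silu(x) = x * sigmoid(x); overwrite gating (= w1(x)) with silu(w1(x)) * w3(x).
static void silu_gate(std::vector<float>& gating, const std::vector<float>& inter)
{
    for (size_t i = 0; i < gating.size(); ++i) {
        const float g = gating[i];
        gating[i] = g / (1.f + std::exp(-g)) * inter[i];
    }
}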
......@@ -28,14 +28,15 @@
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
namespace turbomind {
......@@ -54,7 +55,8 @@ LlamaV2<T>::LlamaV2(size_t head_num,
int step_length,
int start_id,
int end_id,
int cache_max_entry_count,
float cache_max_block_count,
int cache_block_seq_len,
int cache_chunk_size,
int quant_policy,
bool use_context_fmha,
......@@ -71,12 +73,14 @@ LlamaV2<T>::LlamaV2(size_t head_num,
inter_size_(inter_size),
num_layer_(num_layer),
vocab_size_(vocab_size),
attn_params_(attn_params),
vocab_size_padded_(vocab_size),
rmsnorm_eps_(norm_eps),
start_id_(start_id),
end_id_(end_id),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_),
local_kv_head_num_(head_num / tensor_para.world_size_),
weights_(weights),
tensor_para_(tensor_para),
stream_(stream),
......@@ -86,7 +90,6 @@ LlamaV2<T>::LlamaV2(size_t head_num,
cuda_device_prop_(cuda_device_prop),
debug_(isDebug()),
step_length_(step_length),
batch_(max_batch_size, max_context_token_num, session_len, this),
shared_state_(shared_state)
{
......@@ -110,25 +113,38 @@ LlamaV2<T>::LlamaV2(size_t head_num,
const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_,
local_kv_head_num,
size_per_head_,
session_len,
elem_bits,
cache_max_entry_count,
cache_chunk_size,
tensor_para.rank_,
allocator);
initialize(attn_params, kv_head_num, use_context_fmha, quant_policy);
start();
auto sequence_manager = std::make_unique<SequenceManager>(num_layer,
local_kv_head_num,
size_per_head_,
cache_block_seq_len,
cache_max_block_count,
cache_chunk_size,
elem_bits,
tensor_para_.rank_,
allocator);
const size_t max_session_len = sequence_manager->max_block_count() * cache_block_seq_len;
if (max_session_len < session_len) {
if (tensor_para.rank_ == 0) {
TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
session_len,
max_session_len);
}
session_len = max_session_len;
}
batch_ = std::make_unique<LlamaBatch<T>>(
max_batch_size, max_context_token_num, session_len, std::move(sequence_manager), this);
initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy);
/// TODO: decouple Llama model and batch inference
batch_->Start();
}
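The check above bounds `session_len` by what the block pool can actually back: max_session_len = max_block_count * cache_block_seq_len. With hypothetical numbers, 256 free blocks of 128 tokens each give 32768 tokens, so a requested session_len of 65536 would be truncated to 32768. A one-line helper expressing the same clamp:

#include <algorithm>
#include <cstddef>

// Clamp the requested session length to the capacity of the block pool.
inline size_t clamp_session_len(size_t requested, size_t max_block_count, size_t block_seq_len)
{
    return std::min(requested, max_block_count * block_seq_len);
}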
template<typename T>
LlamaV2<T>::~LlamaV2()
{
shared_state_->request_queue.close();
internal_thread_.join();
delete decoder_;
delete dynamic_decode_layer_;
delete context_decoder_;
......@@ -138,6 +154,7 @@ template<typename T>
void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
bool use_context_fmha,
int cache_block_seq_len,
int quant_policy)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
......@@ -155,6 +172,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
allocator_,
is_free_buffer_after_forward_,
use_context_fmha,
cache_block_seq_len,
quant_policy);
decoder_ = new LlamaDecoder<T>(head_num_,
......@@ -169,6 +187,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
cache_block_seq_len,
quant_policy);
dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
......@@ -184,6 +203,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
template<typename T>
void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step)
{
NvtxScope scope("embeddingLookup");
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
// ! This kernel can't be used in context decoding
invokeEmbeddingLookupPosEncodingPadCount(embeddings,
......@@ -202,20 +222,23 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
}
template<typename T>
void LlamaV2<T>::contextDecode(T* decoder_output,
uintptr_t* k_cache_ptr,
uintptr_t* v_cache_ptr,
T* context_decoder_input_buf,
T* context_decoder_output_buf,
const int* input_ids,
const int* input_length,
const int* history_length,
const int* context_length,
size_t token_num,
size_t max_input_len,
size_t max_context_len,
size_t session_len,
size_t batch_size)
void LlamaV2<T>::contextDecode(T* decoder_output,
uintptr_t* k_cache_ptr,
uintptr_t* v_cache_ptr,
void** tmp_k_ptrs,
void** tmp_v_ptrs,
T* context_decoder_input_buf,
T* context_decoder_output_buf,
const int* input_ids,
const int* input_length,
const int* context_length,
const int* cu_block_counts,
const float* rope_theta,
size_t token_num,
size_t max_input_len,
size_t max_context_len,
size_t session_len,
size_t batch_size)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
......@@ -248,17 +271,19 @@ void LlamaV2<T>::contextDecode(T* deocder_output,
{"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_input_buf}},
{"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
{"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, input_length}},
{"history_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, history_length}},
{"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, context_length}},
{"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}},
{"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}},
{"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
};
{"rope_theta", {MEMORY_GPU, TYPE_FP32, {hidden_units_}, rope_theta}},
{"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}}};
std::unordered_map<std::string, Tensor> decoder_output_tensors{
{"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_output_buf}},
{"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}},
{"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}},
{"tmp_k", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_k_ptrs}},
{"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_v_ptrs}},
{"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, deocder_output}}};
context_decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
......@@ -269,32 +294,35 @@ void LlamaV2<T>::contextDecode(T* decoder_output,
}
template<typename T>
void LlamaV2<T>::decoderForward(T* decoder_output,
uintptr_t* k_cache_ptr,
uintptr_t* v_cache_ptr,
T* decoder_input,
const int* sequence_length,
const int* total_padding_count,
bool* finished,
int step,
int ite,
size_t session_len,
size_t batch_size)
void LlamaV2<T>::decoderForward(T* decoder_output,
uintptr_t* k_cache_ptr,
uintptr_t* v_cache_ptr,
T* decoder_input,
const int* sequence_length,
const bool* finished,
const int* cu_block_counts,
const float* rope_theta,
int step,
int ite,
int sum_seq_len,
int max_seq_len,
size_t batch_size)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
const int max_seq_len = session_len;
const auto dtype = getTensorType<T>();
const auto dtype = getTensorType<T>();
// max_input_length is not used w/o linear_bias_slopes
// sequence_lengths_ will be incremented in dynamic decode
std::unordered_map<std::string, Tensor> decoder_input_tensors{
{"decoder_input", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_input}},
{"sequence_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
{"total_padding_tokens", {MEMORY_GPU, TYPE_INT32, {batch_size}, total_padding_count}},
{"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}},
{"sum_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &sum_seq_len}},
{"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
{"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
{"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
{"rope_theta", {MEMORY_GPU, TYPE_FP32, {batch_size}, rope_theta}},
{"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
{"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}},
};
......@@ -312,6 +340,7 @@ void LlamaV2<T>::decoderForward(T* decoder_output,
template<typename T>
void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size)
{
NvtxScope scope("postDecodeEmbedding");
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
cudaDataType_t data_type = getCudaDataType<T>();
float alpha = 1.f;
......@@ -389,6 +418,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
size_t token_ids_len,
size_t batch_size)
{
NvtxScope scope("dynamicDecode");
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
int local_batch_size = (int)batch_size;
......@@ -432,83 +462,6 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
}
template<typename T>
void LlamaV2<T>::internalThreadEntry(int device_id)
{
TM_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_);
check_cuda_error(cudaSetDevice(device_id));
auto& request_queue = shared_state_->request_queue;
auto& infer_requests = shared_state_->infer_requests;
auto& stop_requests = shared_state_->stop_requests;
while (1) {
if (tensor_para_.rank_ == 0) {
const int free_slot_count = batch_.maxSize() - batch_.size() + batch_.finishedCount();
const bool is_empty = free_slot_count == batch_.maxSize();
request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty);
// request queue was closed
// and there are no unprocessed requests in the queue
if (is_empty && infer_requests.empty() && stop_requests.empty()) {
// rank 0 sets flag
shared_state_->should_stop = true;
}
batch_.verifyRequests(stop_requests, infer_requests);
}
// wait while rank-0 is dequeueing
shared_state_->barrier->wait();
// exit if job is done
if (shared_state_->should_stop) {
return;
}
bool modified = false;
if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) {
batch_.handleStopRequests(stop_requests);
batch_.synchronize();
modified = true;
}
const int infer_request_count = infer_requests.size();
if (!infer_requests.empty()) {
batch_.initialize(infer_requests); // reinitialize when new requests come, possible buffer allocation
batch_.contextDecode();
modified = true;
}
// wait while shared stop/infer_requests is being used
shared_state_->barrier->wait();
if (batch_.size()) {
if (modified) {
batch_.initializeGeneration();
batch_.initializeSampling(infer_request_count);
}
for (int i = 0; i < step_length_; ++i) {
if (!batch_.generate()) {
break;
}
}
batch_.finish();
}
}
}
template<typename T>
void LlamaV2<T>::start()
{
int device_id = -1;
check_cuda_error(cudaGetDevice(&device_id));
internal_thread_ = std::thread(&LlamaV2<T>::internalThreadEntry, this, device_id);
}
static inline Tensor slice(const Tensor& tensor, int index)
{
auto shape = tensor.shape;
......@@ -591,15 +544,25 @@ void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>* outputs,
bool has_error = 0;
if (rank == 0) {
TM_LOG_INFO("[forward] Enqueue requests");
std::vector<uint64_t> ids;
for (const auto& r : requests) {
ids.push_back(r->id);
}
auto futures = shared_state_->request_queue.enqueue(std::move(requests));
FT_CHECK_WITH_INFO(ids.size() == futures.size(), "check failed");
TM_LOG_INFO("[forward] Wait for requests to complete ...");
for (auto& f : futures) {
auto ec = f.get();
for (int i = 0; i < futures.size(); ++i) {
auto ec = futures[i].get();
error_codes.push_back(ec);
if (ec) {
has_error = true;
}
TM_LOG_INFO("[forward] Request complete for %ld, ec = %d", (long)ids[i], (int)ec);
}
}
......
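Completion is now reported per request id: rank 0 records the ids before moving the requests into the queue, then matches each returned future's error code back to its id when logging. A minimal sketch of the same id-tracked enqueue-and-wait pattern, using a hypothetical stand-in request type rather than TurboMind's real Request/RequestQueue:

#include <cstdint>
#include <cstdio>
#include <future>
#include <vector>

// Hypothetical request type: carries an id and a promise fulfilled on completion.
struct FakeRequest {
    uint64_t          id;
    std::promise<int> done;
};

int main()
{
    std::vector<FakeRequest> requests(2);
    requests[0].id = 7;
    requests[1].id = 8;

    std::vector<uint64_t>         ids;
    std::vector<std::future<int>> futures;
    for (auto& r : requests) {
        ids.push_back(r.id);
        futures.push_back(r.done.get_future());
    }

    // a worker thread would normally fulfil these asynchronously
    for (auto& r : requests) {
        r.done.set_value(0);  // 0 == success
    }

    for (size_t i = 0; i < futures.size(); ++i) {
        const int ec = futures[i].get();
        std::printf("request %llu finished, ec = %d\n", (unsigned long long)ids[i], ec);
    }
    return 0;
}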