Commit 7f943a26 authored by Li Zhang, committed by GitHub

Unify prefill & decode passes (#775)

* Unify prefill and decode passes

* dynamic split-fuse

* refactor

* correct input count calculation

* remove unused

* lint

* lint

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build
parent 2ba90822
...@@ -48,7 +48,9 @@ __global__ void getPaddingOffsetAndCuSeqLensKernel(size_t* h_valid_word_num, ...@@ -48,7 +48,9 @@ __global__ void getPaddingOffsetAndCuSeqLensKernel(size_t* h_valid_word_num,
if (calculate_cu_seqlens) { if (calculate_cu_seqlens) {
cu_seqlens[batch_size] = total_seq_len; cu_seqlens[batch_size] = total_seq_len;
} }
h_valid_word_num[0] = (size_t)total_seq_len; if (h_valid_word_num) {
h_valid_word_num[0] = (size_t)total_seq_len;
}
} }
void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num, void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num,
...@@ -60,15 +62,19 @@ void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num, ...@@ -60,15 +62,19 @@ void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num,
const int max_seq_len, const int max_seq_len,
cudaStream_t stream) cudaStream_t stream)
{ {
h_pinned_token_num[0] = 0; if (h_pinned_token_num) {
h_pinned_token_num[0] = 0;
}
getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>( getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>(
h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len); h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len);
if (h_pinned_token_num) {
#ifdef _MSC_VER #ifdef _MSC_VER
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
#else #else
while (((volatile size_t*)h_pinned_token_num)[0] == 0) {}; while (((volatile size_t*)h_pinned_token_num)[0] == 0) {};
#endif #endif
h_token_num[0] = h_pinned_token_num[0]; h_token_num[0] = h_pinned_token_num[0];
}
sync_check_cuda_error(); sync_check_cuda_error();
} }
......
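The null checks added above make the pinned-memory handshake optional: the kernel publishes the token count to pinned host memory, and the host polls it, only when a pinned buffer is actually supplied. Below is a minimal standalone sketch of that handshake for reference; it is not part of this commit, the names and sizes are made up, and it assumes unified virtual addressing so the device can write through the pinned host pointer.

#include <cstdio>
#include <cuda_runtime.h>

// Sums the sequence lengths and writes the total straight into pinned
// (page-locked) host memory, mirroring getPaddingOffsetAndCuSeqLensKernel.
__global__ void publishTokenCount(size_t* h_total, const int* seq_lens, int batch_size)
{
    size_t total = 0;
    for (int i = 0; i < batch_size; ++i) {
        total += seq_lens[i];
    }
    h_total[0] = total;
}

int main()
{
    size_t* h_total = nullptr;
    cudaHostAlloc((void**)&h_total, sizeof(size_t), cudaHostAllocDefault);
    h_total[0] = 0;

    int  h_lens[4] = {3, 5, 2, 7};
    int* d_lens    = nullptr;
    cudaMalloc((void**)&d_lens, sizeof(h_lens));
    cudaMemcpy(d_lens, h_lens, sizeof(h_lens), cudaMemcpyHostToDevice);

    publishTokenCount<<<1, 1>>>(h_total, d_lens, 4);

    // Busy-wait until the kernel's write becomes visible; the patch falls back
    // to cudaStreamSynchronize under MSVC instead of this spin loop.
    while (((volatile size_t*)h_total)[0] == 0) {}

    printf("total tokens = %zu\n", h_total[0]);

    cudaFree(d_lens);
    cudaFreeHost(h_total);
    return 0;
}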
...@@ -20,13 +20,11 @@ struct DecoderMultiHeadAttentionParams { ...@@ -20,13 +20,11 @@ struct DecoderMultiHeadAttentionParams {
T* __restrict__ v_bias; T* __restrict__ v_bias;
// sequence-level buffers // sequence-level buffers
const int* __restrict__ per_sample_length; const int* __restrict__ context_length;
const bool* __restrict__ finished; const bool* __restrict__ finished;
const float* __restrict__ rope_theta; const float* __restrict__ rope_theta;
// kv cache // kv cache
void** __restrict__ per_sample_k_cache; // [H, S, D]
void** __restrict__ per_sample_v_cache; // [H, S, D]
size_t layer_offset; size_t layer_offset;
/// cache layout M,[N,H,x,D] /// cache layout M,[N,H,x,D]
......
...@@ -145,7 +145,7 @@ struct DecoderMultiHeadAttentionKernel { ...@@ -145,7 +145,7 @@ struct DecoderMultiHeadAttentionKernel {
kv_head_idx_ = head_idx_ / gqa_group_size; kv_head_idx_ = head_idx_ / gqa_group_size;
is_gqa_leader_ = head_idx_ % gqa_group_size == 0; is_gqa_leader_ = head_idx_ % gqa_group_size == 0;
timestep_ = params_.per_sample_length[batch_idx_]; timestep_ = params_.context_length[batch_idx_] - 1;
if (kSplitK && params.max_split_k > 1) { if (kSplitK && params.max_split_k > 1) {
const int slice_count = (timestep_ + kSliceLen - 1) / kSliceLen; const int slice_count = (timestep_ + kSliceLen - 1) / kSliceLen;
...@@ -815,7 +815,7 @@ struct DecoderMultiHeadAttentionKernel { ...@@ -815,7 +815,7 @@ struct DecoderMultiHeadAttentionKernel {
{ {
const int batch_idx = get_batch_idx(); const int batch_idx = get_batch_idx();
const int head_idx = get_head_idx(); const int head_idx = get_head_idx();
const int timestep = params.per_sample_length[batch_idx]; const int timestep = params.context_length[batch_idx] - 1;
const int max_split_k = params.max_split_k; const int max_split_k = params.max_split_k;
const int slice_count = get_slice_count(timestep); const int slice_count = get_slice_count(timestep);
const int slice_per_split = (slice_count + max_split_k - 1) / max_split_k; const int slice_per_split = (slice_count + max_split_k - 1) / max_split_k;
......
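For reference, the renaming above follows the convention stated later in this commit (context_length = sequence_length + 1): `context_length` counts every token of the sequence including the one currently being processed, while the kernel's `timestep` is the number of tokens whose keys and values are already in the cache, hence the `- 1`. A trivial illustration, not part of the commit:

#include <cassert>

int main()
{
    const int context_length = 1025;                // all tokens, including the one being processed now
    const int timestep       = context_length - 1;  // tokens whose K/V entries are already cached
    assert(timestep == 1024);
    return 0;
}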
...@@ -53,7 +53,7 @@ void TestBlocks(thrust::universal_vector<half>& linear, // linear data ...@@ -53,7 +53,7 @@ void TestBlocks(thrust::universal_vector<half>& linear, // linear data
std::mt19937 g(rd()); std::mt19937 g(rd());
std::shuffle(idxs.begin(), idxs.end(), g); std::shuffle(idxs.begin(), idxs.end(), g);
for (int i = 0; i < idxs.size(); ++i) { for (size_t i = 0; i < idxs.size(); ++i) {
ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim; ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
} }
...@@ -115,8 +115,8 @@ int main(int argc, char* argv[]) ...@@ -115,8 +115,8 @@ int main(int argc, char* argv[])
constexpr int KvHeadNum = 32; constexpr int KvHeadNum = 32;
constexpr int kBatchSize = 1; constexpr int kBatchSize = 1;
// constexpr int kContextLen = 7306; // constexpr int kContextLen = 7306;
constexpr int kContextLen = 1024; constexpr int kSequenceLen = 1024;
constexpr int kSequenceLen = kContextLen + 1; constexpr int kContextLen = kSequenceLen + 1;
constexpr int kBlockSz = 128; constexpr int kBlockSz = 128;
constexpr int kTestIter = 10; constexpr int kTestIter = 10;
constexpr int kMaxSplitK = 1; constexpr int kMaxSplitK = 1;
...@@ -126,9 +126,10 @@ int main(int argc, char* argv[]) ...@@ -126,9 +126,10 @@ int main(int argc, char* argv[])
thrust::universal_vector<half> output(kBatchSize * kHeadNum * kHeadDim); thrust::universal_vector<half> output(kBatchSize * kHeadNum * kHeadDim);
thrust::universal_vector<half> qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim); thrust::universal_vector<half> qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim);
thrust::universal_vector<bool> finished(kBatchSize); thrust::universal_vector<bool> finished(kBatchSize);
thrust::universal_vector<half> k_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim); thrust::universal_vector<half> k_cache(kBatchSize * kContextLen * KvHeadNum * kHeadDim);
thrust::universal_vector<half> v_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim); thrust::universal_vector<half> v_cache(kBatchSize * kContextLen * KvHeadNum * kHeadDim);
thrust::universal_vector<int> sequence_lengths(kBatchSize); thrust::universal_vector<int> context_length(kBatchSize);
thrust::universal_vector<int> sequence_length(kBatchSize);
thrust::universal_vector<void*> k_cache_ptrs(kBatchSize); thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
thrust::universal_vector<void*> v_cache_ptrs(kBatchSize); thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
...@@ -138,23 +139,23 @@ int main(int argc, char* argv[]) ...@@ -138,23 +139,23 @@ int main(int argc, char* argv[])
rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f); rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
if (kContextLen) { if (kSequenceLen) {
rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim); rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim); rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim, cudaMemset2DAsync(k_cache.data().get() + kSequenceLen * kHeadDim,
sizeof(half) * kSequenceLen * kHeadDim, sizeof(half) * kContextLen * kHeadDim,
0, 0,
sizeof(half) * kHeadDim, sizeof(half) * kHeadDim,
kBatchSize * KvHeadNum); kBatchSize * KvHeadNum);
if constexpr (0) { if constexpr (0) {
for (int b = 0; b < kBatchSize; ++b) { for (int b = 0; b < kBatchSize; ++b) {
for (int h = 0; h < KvHeadNum; ++h) { for (int h = 0; h < KvHeadNum; ++h) {
for (int s = 0; s < kSequenceLen; ++s) { for (int s = 0; s < kContextLen; ++s) {
for (int d = 0; d < kHeadDim; ++d) { for (int d = 0; d < kHeadDim; ++d) {
std::cout << std::setw(7) << std::setprecision(4) << std::fixed std::cout << std::setw(7) << std::setprecision(4) << std::fixed
<< (float)k_cache[b * KvHeadNum * kSequenceLen * kHeadDim << (float)k_cache[b * KvHeadNum * kContextLen * kHeadDim
+ h * kSequenceLen * kHeadDim + s * kHeadDim + d] + h * kContextLen * kHeadDim + s * kHeadDim + d]
<< " "; << " ";
} }
std::cout << "\n"; std::cout << "\n";
...@@ -166,8 +167,8 @@ int main(int argc, char* argv[]) ...@@ -166,8 +167,8 @@ int main(int argc, char* argv[])
std::exit(0); std::exit(0);
} }
cudaMemset2DAsync(v_cache.data().get() + kContextLen * kHeadDim, cudaMemset2DAsync(v_cache.data().get() + kSequenceLen * kHeadDim,
sizeof(half) * kSequenceLen * kHeadDim, sizeof(half) * kContextLen * kHeadDim,
0, 0,
sizeof(half) * kHeadDim, sizeof(half) * kHeadDim,
kBatchSize * KvHeadNum); kBatchSize * KvHeadNum);
...@@ -193,7 +194,8 @@ int main(int argc, char* argv[]) ...@@ -193,7 +194,8 @@ int main(int argc, char* argv[])
cudaDeviceSynchronize(); cudaDeviceSynchronize();
for (int i = 0; i < kBatchSize; ++i) { for (int i = 0; i < kBatchSize; ++i) {
sequence_lengths[i] = kContextLen; sequence_length[i] = kSequenceLen;
context_length[i] = kContextLen;
k_cache_ptrs[i] = k_cache.data().get() + i * k_cache.size() / kBatchSize; k_cache_ptrs[i] = k_cache.data().get() + i * k_cache.size() / kBatchSize;
v_cache_ptrs[i] = v_cache.data().get() + i * v_cache.size() / kBatchSize; v_cache_ptrs[i] = v_cache.data().get() + i * v_cache.size() / kBatchSize;
k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize; k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
...@@ -212,7 +214,7 @@ int main(int argc, char* argv[]) ...@@ -212,7 +214,7 @@ int main(int argc, char* argv[])
params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim; params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;
params.batch_size = kBatchSize; params.batch_size = kBatchSize;
params.max_seq_len = kContextLen + 1; params.max_seq_len = kSequenceLen;
params.cu_block_cnts = cu_block_cnts.data().get(); params.cu_block_cnts = cu_block_cnts.data().get();
printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size()); printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
...@@ -220,11 +222,9 @@ int main(int argc, char* argv[]) ...@@ -220,11 +222,9 @@ int main(int argc, char* argv[])
params.v_cache_block_ptrs = (void**)v_ptrs.data().get(); params.v_cache_block_ptrs = (void**)v_ptrs.data().get();
params.kv_cache_block_size = kBlockSz; params.kv_cache_block_size = kBlockSz;
params.finished = finished.data().get(); params.finished = finished.data().get();
params.per_sample_length = sequence_lengths.data().get(); params.context_length = context_length.data().get();
params.per_sample_k_cache = k_cache_ref_ptrs.data().get(); params.layer_offset = 0;
params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
params.layer_offset = 0;
params.num_heads = kHeadNum; params.num_heads = kHeadNum;
params.num_kv_heads = KvHeadNum; params.num_kv_heads = KvHeadNum;
...@@ -238,8 +238,16 @@ int main(int argc, char* argv[]) ...@@ -238,8 +238,16 @@ int main(int argc, char* argv[])
params.partial_M = partial_M.data().get(); params.partial_M = partial_M.data().get();
params.partial_O = partial_O.data().get(); params.partial_O = partial_O.data().get();
params.max_split_k = kMaxSplitK;
params.arch = 80;
for (int i = 0; i < kTestIter; ++i) { for (int i = 0; i < kTestIter; ++i) {
mmha_ft_reference(params, cudaStream_t{}); mmha_ft_reference(params,
(half**)k_cache_ref_ptrs.data().get(),
(half**)v_cache_ref_ptrs.data().get(),
sequence_length.data().get(),
kContextLen,
cudaStream_t{});
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
...@@ -249,14 +257,7 @@ int main(int argc, char* argv[]) ...@@ -249,14 +257,7 @@ int main(int argc, char* argv[])
} }
std::cout << "---------------------------------------------------\n"; std::cout << "---------------------------------------------------\n";
params.out = output.data().get(); params.out = output.data().get();
params.per_sample_k_cache = k_cache_ptrs.data().get();
params.per_sample_v_cache = v_cache_ptrs.data().get();
params.max_split_k = kMaxSplitK;
params.max_seq_len = kContextLen;
params.arch = 80;
std::vector<thrust::universal_vector<half>> outputs; std::vector<thrust::universal_vector<half>> outputs;
...@@ -271,19 +272,14 @@ int main(int argc, char* argv[]) ...@@ -271,19 +272,14 @@ int main(int argc, char* argv[])
} }
} }
thrust::universal_vector<int> seq_lens(kBatchSize);
for (auto& x : seq_lens) {
x = kContextLen + 1;
}
if (1) { if (1) {
ConvertBlocksToLinear((const half**)k_ptrs.data().get(), ConvertBlocksToLinear((const half**)k_ptrs.data().get(),
k_cache.data().get(), k_cache.data().get(),
cu_block_cnts.data().get(), cu_block_cnts.data().get(),
seq_lens.data().get(), context_length.data().get(),
0, 0,
kBlockSz, kBlockSz,
kSequenceLen, kContextLen,
KvHeadNum, KvHeadNum,
kHeadDim, kHeadDim,
kBatchSize, kBatchSize,
...@@ -291,10 +287,10 @@ int main(int argc, char* argv[]) ...@@ -291,10 +287,10 @@ int main(int argc, char* argv[])
ConvertBlocksToLinear((const half**)v_ptrs.data().get(), ConvertBlocksToLinear((const half**)v_ptrs.data().get(),
v_cache.data().get(), v_cache.data().get(),
cu_block_cnts.data().get(), cu_block_cnts.data().get(),
seq_lens.data().get(), context_length.data().get(),
0, 0,
kBlockSz, kBlockSz,
kSequenceLen, kContextLen,
KvHeadNum, KvHeadNum,
kHeadDim, kHeadDim,
kBatchSize, kBatchSize,
...@@ -316,15 +312,15 @@ int main(int argc, char* argv[]) ...@@ -316,15 +312,15 @@ int main(int argc, char* argv[])
// [H, S, D] // [H, S, D]
Compare(k_cache.data().get() + kContextLen * kHeadDim, Compare(k_cache.data().get() + kSequenceLen * kHeadDim,
k_cache_ref.data().get() + kContextLen * kHeadDim, k_cache_ref.data().get() + kSequenceLen * kHeadDim,
kSequenceLen * kHeadDim, kContextLen * kHeadDim,
kHeadDim, kHeadDim,
KvHeadNum); KvHeadNum);
Compare(v_cache.data().get() + kContextLen * kHeadDim, Compare(v_cache.data().get() + kSequenceLen * kHeadDim,
v_cache_ref.data().get() + kContextLen * kHeadDim, v_cache_ref.data().get() + kSequenceLen * kHeadDim,
kSequenceLen * kHeadDim, kContextLen * kHeadDim,
kHeadDim, kHeadDim,
KvHeadNum); KvHeadNum);
......
...@@ -182,7 +182,12 @@ struct SATypeConverter<half> { ...@@ -182,7 +182,12 @@ struct SATypeConverter<half> {
}; };
template<typename T> template<typename T>
void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t st) void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p,
T** per_sample_k_cache,
T** per_sample_v_cache,
const int* sequence_length,
int max_memory_len,
cudaStream_t st)
{ {
using DataType = typename SATypeConverter<T>::Type; using DataType = typename SATypeConverter<T>::Type;
...@@ -204,18 +209,18 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t ...@@ -204,18 +209,18 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
params.stride = p.stride; params.stride = p.stride;
params.finished = (bool*)p.finished; params.finished = (bool*)p.finished;
params.k_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_k_cache); params.k_cache_per_sample = reinterpret_cast<DataType**>(per_sample_k_cache);
params.v_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_v_cache); params.v_cache_per_sample = reinterpret_cast<DataType**>(per_sample_v_cache);
params.kv_cache_per_sample_offset = p.layer_offset; params.kv_cache_per_sample_offset = p.layer_offset;
params.batch_size = p.batch_size; params.batch_size = p.batch_size;
params.beam_width = 1; params.beam_width = 1;
params.memory_max_len = p.max_seq_len; params.memory_max_len = max_memory_len;
params.prefix_prompt_lengths = 0; params.prefix_prompt_lengths = 0;
params.max_prefix_prompt_length = 0; params.max_prefix_prompt_length = 0;
params.length_per_sample = p.per_sample_length; // max_input_length + current output length params.length_per_sample = sequence_length; // max_input_length + current output length
for (int i = 0; i < p.batch_size; ++i) { for (int i = 0; i < p.batch_size; ++i) {
params.timestep = std::max(p.per_sample_length[i], params.timestep); params.timestep = std::max(sequence_length[i], params.timestep);
} }
std::cout << "timestep = " << params.timestep << "\n"; std::cout << "timestep = " << params.timestep << "\n";
...@@ -237,6 +242,11 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t ...@@ -237,6 +242,11 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
masked_multihead_attention(params, st); masked_multihead_attention(params, st);
} }
template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params, cudaStream_t st); template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params,
half** per_sample_k_cache,
half** per_sample_v_cache,
const int* sequence_length,
int max_memory_len,
cudaStream_t st);
} // namespace turbomind } // namespace turbomind
...@@ -33,6 +33,11 @@ private: ...@@ -33,6 +33,11 @@ private:
}; };
template<typename T> template<typename T>
void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params, cudaStream_t st); void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params,
T** per_sample_k_cache,
T** per_sample_v_cache,
const int* sequence_length,
int max_memory_len,
cudaStream_t st);
} // namespace turbomind } // namespace turbomind
...@@ -34,10 +34,11 @@ public: ...@@ -34,10 +34,11 @@ public:
class Barrier { class Barrier {
public: public:
Barrier(unsigned count) Barrier(unsigned count): count_(count)
{ {
TM_LOG_INFO("Barrier(%d)", (int)count); if (count_ > 1) {
pthread_barrier_init(&barrier_, nullptr, count); pthread_barrier_init(&barrier_, nullptr, count);
}
} }
Barrier(const Barrier&) = delete; Barrier(const Barrier&) = delete;
...@@ -47,15 +48,20 @@ public: ...@@ -47,15 +48,20 @@ public:
void wait() void wait()
{ {
pthread_barrier_wait(&barrier_); if (count_ > 1) {
pthread_barrier_wait(&barrier_);
}
} }
~Barrier() ~Barrier()
{ {
pthread_barrier_destroy(&barrier_); if (count_ > 1) {
pthread_barrier_destroy(&barrier_);
}
} }
private: private:
const int count_;
pthread_barrier_t barrier_{}; pthread_barrier_t barrier_{};
}; };
......
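The guarded Barrier above becomes a no-op when constructed with a count of 1, so single-rank runs never touch pthread barriers; with a larger count it behaves like a plain pthread barrier. The following self-contained usage sketch reproduces the patched class for illustration only and is not part of the commit:

#include <pthread.h>
#include <cstdio>
#include <thread>
#include <vector>

class Barrier {
public:
    explicit Barrier(unsigned count): count_(count)
    {
        if (count_ > 1) {
            pthread_barrier_init(&barrier_, nullptr, count);
        }
    }
    void wait()
    {
        if (count_ > 1) {
            pthread_barrier_wait(&barrier_);
        }
    }
    ~Barrier()
    {
        if (count_ > 1) {
            pthread_barrier_destroy(&barrier_);
        }
    }
private:
    const int         count_;
    pthread_barrier_t barrier_{};
};

int main()
{
    Barrier barrier(4);  // with Barrier(1) every wait() would return immediately
    std::vector<std::thread> workers;
    for (int rank = 0; rank < 4; ++rank) {
        workers.emplace_back([&barrier, rank] {
            printf("rank %d before barrier\n", rank);
            barrier.wait();  // all four ranks arrive here before any of them continues
            printf("rank %d after barrier\n", rank);
        });
    }
    for (auto& w : workers) {
        w.join();
    }
    return 0;
}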
...@@ -9,16 +9,13 @@ find_package(CUDAToolkit REQUIRED) ...@@ -9,16 +9,13 @@ find_package(CUDAToolkit REQUIRED)
add_library(Llama STATIC add_library(Llama STATIC
LlamaV2.cc LlamaV2.cc
LlamaBatch.cc LlamaBatch.cc
LlamaCacheManager.cc
BlockManager.cc BlockManager.cc
SequenceManager.cc SequenceManager.cc
LlamaContextDecoder.cc
LlamaContextAttentionLayer.cc
LlamaDecoderSelfAttentionLayer.cc
LlamaDecoder.cc
LlamaWeight.cc LlamaWeight.cc
LlamaDecoderLayerWeight.cc LlamaDecoderLayerWeight.cc
LlamaFfnLayer.cc LlamaFfnLayer.cc
unified_decoder.cc
unified_attention_layer.cc
llama_kernels.cu llama_kernels.cu
llama_decoder_kernels.cu llama_decoder_kernels.cu
llama_utils.cu) llama_utils.cu)
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "src/turbomind/models/llama/LlamaV2.h" #include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/Request.h" #include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h" #include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/copy.h"
#include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/Tensor.h"
...@@ -19,6 +20,7 @@ ...@@ -19,6 +20,7 @@
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional>
#include <iomanip> #include <iomanip>
#include <iterator> #include <iterator>
#include <mutex> #include <mutex>
...@@ -29,6 +31,28 @@ ...@@ -29,6 +31,28 @@
namespace turbomind { namespace turbomind {
void PrintDecodeTokens(
const int* token_ids, int max_seq_len, int batch_size, cudaStream_t stream, const std::string& msg)
{
// tokens in [S, B] layout
std::vector<int> tokens(max_seq_len * batch_size);
check_cuda_error(cudaMemcpyAsync(tokens.data(), token_ids, sizeof(int) * tokens.size(), cudaMemcpyDefault, stream));
check_cuda_error(cudaStreamSynchronize(stream));
printf("[%s] ", msg.c_str());
for (int j = 0; j < max_seq_len; ++j) {
printf("%5d ", j);
}
printf("\n");
for (int i = 0; i < batch_size; ++i) {
printf("[%s] ", msg.c_str());
for (int j = 0; j < max_seq_len; ++j) {
printf("%5d ", tokens[j * batch_size + i]);
}
printf("\n");
}
}
void ClearState(BatchState& s) void ClearState(BatchState& s)
{ {
std::fill_n(s.requests.begin(), s.size, nullptr); std::fill_n(s.requests.begin(), s.size, nullptr);
...@@ -297,7 +321,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests) ...@@ -297,7 +321,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
// ! SHARED STATE IS MODIFIED, BARRIER SYNCHRONIZATION REQUIRED // ! SHARED STATE IS MODIFIED, BARRIER SYNCHRONIZATION REQUIRED
// assign priority based on arrival time // assign priority based on arrival time
if (rank_ == 0) { if (rank_ == 0) {
r->priority = request_count_++; r->unique_id = request_count_++;
} }
// increment pointer // increment pointer
...@@ -324,7 +348,39 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests) ...@@ -324,7 +348,39 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
} }
template<typename T> template<typename T>
bool LlamaBatch<T>::Initialize() void LlamaBatch<T>::AdjustMaxInputCount(GenerationState& g,
const std::vector<const Sequence*>& sequences,
const std::vector<int>& context_length)
{
int input_count = 0;
for (int i = 0; i < sequences.size(); ++i) {
input_count += context_length[i] - sequences[i]->cache_len;
}
const int batch_size = sequences.size();
input_count -= batch_size;
// minimum tokens per iteration required to finish all pending prefills within `max_prefill_iters_` iterations
input_count = (input_count + max_prefill_iters_ - 1) / max_prefill_iters_;
if (g.min_input_count.empty()) {
g.min_input_count.resize(max_prefill_iters_);
}
g.min_input_count.pop_front();
g.min_input_count.push_back(input_count);
/// TODO: sub-optimal when there are inactive sequences due to memory constraint
for (auto& x : g.min_input_count) {
x = std::max(x, input_count);
}
input_count = std::max(g.min_input_count.front() + batch_size, num_tokens_per_iter_);
input_count = std::min(input_count, max_context_token_num_);
// update max input count
g.max_input_count1 = input_count;
g.max_input_count2 = std::min(input_count + extra_tokens_per_iter_, max_context_token_num_);
}
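The following standalone sketch (made-up numbers, a single call rather than the per-iteration state kept in GenerationState; not part of the commit) walks through the token-budgeting rule that AdjustMaxInputCount implements for dynamic split-fuse: spread the pending prefill tokens over at most `max_prefill_iters_` passes while never dropping below `num_tokens_per_iter_` or exceeding `max_context_token_num_`.

#include <algorithm>
#include <cstdio>
#include <deque>
#include <numeric>
#include <vector>

int main()
{
    // Engine parameters (made-up values)
    const int max_prefill_iters     = 4;
    const int num_tokens_per_iter   = 256;
    const int max_context_token_num = 4096;

    // Tokens each active sequence still needs to prefill (context_length - cache_len)
    std::vector<int> missing    = {900, 40, 1};
    const int        batch_size = (int)missing.size();

    int input_count = std::accumulate(missing.begin(), missing.end(), 0);
    input_count -= batch_size;  // one token per sequence is left to the decode step, as in the patch
    // minimum tokens per iteration so the prefill completes within max_prefill_iters passes
    input_count = (input_count + max_prefill_iters - 1) / max_prefill_iters;

    // Sliding window of lower bounds; in LlamaBatch this deque lives in GenerationState
    std::deque<int> min_input_count(max_prefill_iters, 0);
    min_input_count.pop_front();
    min_input_count.push_back(input_count);
    for (auto& x : min_input_count) {
        x = std::max(x, input_count);
    }

    int budget = std::max(min_input_count.front() + batch_size, num_tokens_per_iter);
    budget     = std::min(budget, max_context_token_num);

    printf("token budget for this iteration: %d\n", budget);  // prints 256 for these numbers
    return 0;
}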
template<typename T>
void LlamaBatch<T>::Initialize(GenerationState& g)
{ {
NvtxScope scope("initialize"); NvtxScope scope("initialize");
std::vector<const Sequence*> sequences; std::vector<const Sequence*> sequences;
...@@ -346,18 +402,14 @@ bool LlamaBatch<T>::Initialize() ...@@ -346,18 +402,14 @@ bool LlamaBatch<T>::Initialize()
} }
} }
// dbg(holes, active_holes);
auto process = [&](BatchState* state) { auto process = [&](BatchState* state) {
for (int i = 0; i < state->size; ++i) { for (int i = 0; i < state->size; ++i) {
if (auto& r = state->requests[i]) { if (auto& r = state->requests[i]) {
sequences.push_back(state->sequences[i]); sequences.push_back(state->sequences[i]);
status.push_back(state->sequences[i]->status); status.push_back(state->sequences[i]->status);
priorities.push_back(r->priority); priorities.push_back(r->unique_id);
context_lengths.push_back(state->h_context_length[i]); context_lengths.push_back(state->h_context_length[i]);
coords.emplace_back(state, i); coords.emplace_back(state, i);
// clear swap-in flags
state->is_swap_in[i] = 0;
} }
} }
}; };
...@@ -365,7 +417,14 @@ bool LlamaBatch<T>::Initialize() ...@@ -365,7 +417,14 @@ bool LlamaBatch<T>::Initialize()
process(state_); process(state_);
process(incoming_); process(incoming_);
auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_); auto adjust = [this, &g](const Sequences& sequences,
const std::vector<int>& context_length) -> std::pair<int, int> {
AdjustMaxInputCount(g, sequences, context_length);
return {g.max_input_count1, g.max_input_count2};
};
// TM_LOG_INFO("max_input_count %d", max_input_count);
auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_, adjust);
if (outcome.allocation || outcome.swap_in || outcome.swap_out) { if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
dbg(outcome); dbg(outcome);
...@@ -379,7 +438,7 @@ bool LlamaBatch<T>::Initialize() ...@@ -379,7 +438,7 @@ bool LlamaBatch<T>::Initialize()
if (exchange || holes || incoming_->size) { if (exchange || holes || incoming_->size) {
// put active ones first // put active ones first
auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) { auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
return sequences[idx]->status == Sequence::kActive; // present status return sequences[idx]->status == Sequence::kActive; // current status
}); });
// all blocks are not enough to hold a single sequence // all blocks are not enough to hold a single sequence
...@@ -387,18 +446,21 @@ bool LlamaBatch<T>::Initialize() ...@@ -387,18 +446,21 @@ bool LlamaBatch<T>::Initialize()
FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks."); FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks.");
} }
// move swap-ins to the back // move the partial seq to the back
auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) { auto partial_beg = std::stable_partition(idxs.begin(), active_end, [&](int i) {
return status[idx] == Sequence::kActive; // past status return sequences[i]->cache_len + sequences[i]->input_length == context_lengths[i];
}); });
FT_CHECK(active_end - partial_beg <= 1);
// sort swap-ins according to missing length auto swapin_beg = std::stable_partition(idxs.begin(), partial_beg, [&](int i) {
if (swapin_beg != active_end) { return status[i] == Sequence::kActive; // past status
std::vector<int> missing_len(sequences.size()); });
for (int i = 0; i < sequences.size(); ++i) {
missing_len[i] = context_lengths[i] - sequences[i]->cache_len; // sort swap-ins according to input length
} if (swapin_beg != partial_beg) {
std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; }); std::stable_sort(swapin_beg, partial_beg, [&](int i, int j) {
return sequences[i]->input_length < sequences[j]->input_length;
});
} }
// Copy sequence states to back buffer // Copy sequence states to back buffer
...@@ -406,13 +468,6 @@ bool LlamaBatch<T>::Initialize() ...@@ -406,13 +468,6 @@ bool LlamaBatch<T>::Initialize()
std::vector<std::tuple<BatchState*, BatchState*, int, int>> cpys; std::vector<std::tuple<BatchState*, BatchState*, int, int>> cpys;
for (const auto& i : idxs) { for (const auto& i : idxs) {
auto& s = *sequences[i]; auto& s = *sequences[i];
if (exchange) {
const auto& [state, idx] = coords[i];
// mark swap-ins
if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
state->is_swap_in[idx] = 1;
}
}
if (s.status == Sequence::kActive) { if (s.status == Sequence::kActive) {
++back_->active_size; ++back_->active_size;
} }
...@@ -465,11 +520,59 @@ bool LlamaBatch<T>::Initialize() ...@@ -465,11 +520,59 @@ bool LlamaBatch<T>::Initialize()
Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_); Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
} }
/// Layout of the buffers is changed, generation & sampling need to be re-initialized for correctness when there const int batch_size = state_->active_size;
/// were
// 1. swap-in or swap-out // check if the last sequence is partial
// 2. holes in the active buffer int partial = 0;
return exchange || active_holes; int partial_len = -1;
{
const int i = state_->active_size - 1;
partial = state_->sequences[i]->cache_len + state_->sequences[i]->input_length != state_->h_context_length[i];
if (partial) {
// backup full context length of partial
partial_len = state_->h_context_length[i];
// replace with partial context length
state_->h_context_length[i] = state_->sequences[i]->cache_len + state_->sequences[i]->input_length;
}
}
const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
std::vector<uint64_t> unique_ids(batch_size);
for (int i = 0; i < batch_size; ++i) {
unique_ids[i] = state_->requests[i]->unique_id;
}
// Real-time context length that will change during generation
Copy(state_->h_context_length, batch_size, context_length_buf_);
Copy(state_->h_finished, batch_size, finished_buf_);
Copy(state_->h_rope_theta, batch_size, rope_theta_);
// used for dispatching split-k decoding kernels
const int sum_seq_len =
std::accumulate(state_->h_context_length, state_->h_context_length + batch_size, -batch_size);
const int max_seq_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size) - 1;
// TM_LOG_INFO(
// "[init] batch_size = %d, max_ctx_len = %d, partial = %d", (int)batch_size, (int)max_context_len, partial);
bool skip_init_sampling = std::equal(g.unique_ids.begin(), //
g.unique_ids.end() - g.partial,
unique_ids.begin(),
unique_ids.end() - partial);
g.sum_seq_len = sum_seq_len;
g.max_seq_len = max_seq_len;
g.partial = partial;
g.partial_context_legnth = partial_len;
g.unique_ids = std::move(unique_ids);
g.finished_count = 0;
if (!skip_init_sampling) {
g.max_init_ctx_len = max_context_len;
g.step = max_context_len;
InitializeSampling(g);
}
} }
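A small illustration (numbers made up, not part of the commit) of how the `partial` flag computed above marks the last active sequence when split-fuse schedules only part of its prompt for this pass, and why its full context length is backed up and restored later in `Finish`:

#include <cstdio>

int main()
{
    // Last active slot in the batch (made-up numbers)
    int cache_len      = 0;    // tokens already in the KV cache
    int input_length   = 256;  // tokens scheduled for this forward pass
    int context_length = 900;  // full prompt length

    const bool partial = (cache_len + input_length != context_length);

    int partial_len = -1;
    if (partial) {
        partial_len    = context_length;            // back up the full length
        context_length = cache_len + input_length;  // use the partial length for this pass
    }

    printf("partial = %d, context length this pass = %d, backed-up length = %d\n",
           (int)partial, context_length, partial_len);  // 1, 256, 900
    return 0;
}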
template<typename T> template<typename T>
...@@ -528,7 +631,6 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta ...@@ -528,7 +631,6 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta
d->h_rope_theta[di] = s->h_rope_theta[si]; d->h_rope_theta[di] = s->h_rope_theta[si];
d->seq_len_limit[di] = s->seq_len_limit[si]; d->seq_len_limit[di] = s->seq_len_limit[si];
d->sequences[di] = s->sequences[si]; d->sequences[di] = s->sequences[si];
d->is_swap_in[di] = s->is_swap_in[si];
d->requests[di] = s->requests[si]; d->requests[di] = s->requests[si];
} }
} }
...@@ -564,9 +666,10 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len) ...@@ -564,9 +666,10 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
decoder_input_buf_ = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false); decoder_input_buf_ = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false);
decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false); decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false);
input_ids_buf_ = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true); input_ids_buf_ = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
input_length_buf_ = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam); input_length_buf_ = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam); context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);
init_context_length_ = (int*)allocator_->reMalloc(init_context_length_, sizeof(int) * batchxbeam);
sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false); sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);
...@@ -582,10 +685,6 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len) ...@@ -582,10 +685,6 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
finished_buf_ = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false); finished_buf_ = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false); seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);
request_output_ids_ptrs_ = (int**)allocator_->reMalloc(request_output_ids_ptrs_, sizeof(int*) * batch_size, true);
request_output_ids_lens_ = (int*)allocator_->reMalloc(request_output_ids_lens_, sizeof(int) * batch_size, true);
request_seqlen_ptrs_ = (int**)allocator_->reMalloc(request_seqlen_ptrs_, sizeof(int*) * batch_size, true);
rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false); rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false);
is_allocate_buffer_ = true; is_allocate_buffer_ = true;
...@@ -664,13 +763,6 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size) ...@@ -664,13 +763,6 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
h_seq_limit_len_ = h_seq_limit_len_ =
(uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true); (uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);
h_request_output_ids_ptrs_ =
(int**)allocator_->reMalloc(h_request_output_ids_ptrs_, sizeof(int*) * max_batch_size, true, true);
h_request_output_ids_lens_ =
(int*)allocator_->reMalloc(h_request_output_ids_lens_, sizeof(int) * max_batch_size, true, true);
h_request_seqlen_ptrs_ =
(int**)allocator_->reMalloc(h_request_seqlen_ptrs_, sizeof(int*) * max_batch_size, true, true);
h_output_ids_ = h_output_ids_ =
(int*)allocator_->reMalloc(h_output_ids_, sizeof(int) * max_batch_size * session_len_, false, true); (int*)allocator_->reMalloc(h_output_ids_, sizeof(int) * max_batch_size * session_len_, false, true);
} }
...@@ -698,6 +790,7 @@ void LlamaBatch<T>::FreeBuffer() ...@@ -698,6 +790,7 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&input_ids_buf_); allocator_->free((void**)&input_ids_buf_);
allocator_->free((void**)&input_length_buf_); allocator_->free((void**)&input_length_buf_);
allocator_->free((void**)&context_length_buf_); allocator_->free((void**)&context_length_buf_);
allocator_->free((void**)&init_context_length_);
allocator_->free((void**)&sequence_lengths_); allocator_->free((void**)&sequence_lengths_);
...@@ -723,10 +816,6 @@ void LlamaBatch<T>::FreeBuffer() ...@@ -723,10 +816,6 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&finished_buf_); allocator_->free((void**)&finished_buf_);
allocator_->free((void**)&seq_limit_len_); allocator_->free((void**)&seq_limit_len_);
allocator_->free((void**)&request_output_ids_ptrs_);
allocator_->free((void**)&request_output_ids_lens_);
allocator_->free((void**)&request_seqlen_ptrs_);
allocator_->free((void**)&rope_theta_); allocator_->free((void**)&rope_theta_);
is_allocate_buffer_ = false; is_allocate_buffer_ = false;
...@@ -759,10 +848,6 @@ void LlamaBatch<T>::FreeBuffer() ...@@ -759,10 +848,6 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&h_input_length_buf_, true); allocator_->free((void**)&h_input_length_buf_, true);
allocator_->free((void**)&h_seq_limit_len_, true); allocator_->free((void**)&h_seq_limit_len_, true);
allocator_->free((void**)&h_request_output_ids_ptrs_, true);
allocator_->free((void**)&h_request_output_ids_lens_, true);
allocator_->free((void**)&h_request_seqlen_ptrs_, true);
allocator_->free((void**)&h_output_ids_, true); allocator_->free((void**)&h_output_ids_, true);
is_allocate_persistant_buffer_ = false; is_allocate_persistant_buffer_ = false;
...@@ -770,45 +855,98 @@ void LlamaBatch<T>::FreeBuffer() ...@@ -770,45 +855,98 @@ void LlamaBatch<T>::FreeBuffer()
} }
template<typename T> template<typename T>
LlamaBatch<T>::LlamaBatch(int max_batch_size, LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model):
int max_context_token_num, max_batch_size_(params.max_batch_size),
int session_len, max_context_token_num_(params.max_context_token_num),
std::unique_ptr<SequenceManager> sequence_manager, session_len_(params.session_len),
LlamaV2<T>* model):
max_batch_size_(max_batch_size),
max_context_token_num_(max_context_token_num),
session_len_(session_len),
rank_(model->tensor_para_.rank_), rank_(model->tensor_para_.rank_),
debug_(model->debug_), debug_(model->debug_),
step_length_(model->step_length_), step_length_(params.step_length),
sequence_manager_(std::move(sequence_manager)),
model_(model), model_(model),
data_type_(getTensorType<T>()) data_type_(getTensorType<T>()),
num_tokens_per_iter_(params.num_tokens_per_iter),
extra_tokens_per_iter_(params.extra_tokens_per_iter),
max_prefill_iters_(params.max_prefill_iters)
{ {
stream_ = model_->stream_; stream_ = model_->stream_;
allocator_ = model_->allocator_; allocator_ = model_->allocator_;
cublas_wrapper_ = model_->cublas_wrapper_; cublas_wrapper_ = model_->cublas_wrapper_;
const size_t elem_bits = (quant_policy & QuantPolicy::kCacheKVInt8) ? 8 : sizeof(T) * 8;
sequence_manager_.reset(new SequenceManager{model_->num_layer_,
model_->local_kv_head_num_,
model_->size_per_head_,
(size_t)cache_block_seq_len,
params.cache_max_block_count,
params.cache_chunk_size,
elem_bits,
model->tensor_para_.rank_,
allocator_});
const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
if (max_session_len < session_len_) {
if (rank_ == 0) {
TM_LOG_WARNING("Not enough blocks for `session_len` (%d), `session_len` truncated to %d.",
session_len_,
max_session_len);
}
session_len_ = max_session_len;
}
for (auto& s : states_) { for (auto& s : states_) {
s.requests.resize(max_batch_size); s.requests.resize(max_batch_size_);
s.sequences.resize(max_batch_size); s.sequences.resize(max_batch_size_);
s.seq_len_limit.resize(max_batch_size); s.seq_len_limit.resize(max_batch_size_);
s.is_swap_in.resize(max_batch_size);
} }
state_ = &states_[0]; state_ = &states_[0];
back_ = &states_[1]; back_ = &states_[1];
incoming_ = &states_[2]; incoming_ = &states_[2];
AllocateBuffer(max_batch_size, session_len_); AllocateBuffer(max_batch_size_, session_len_);
AllocatePersistantBuffer(max_batch_size); AllocatePersistantBuffer(max_batch_size_);
} }
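A rough arithmetic sketch (made-up numbers, not part of the commit) of the `session_len` truncation the constructor now performs once it owns the `SequenceManager`: the longest representable session is bounded by the block count times the block length.

#include <cstdio>

int main()
{
    const size_t cache_max_block_count = 300;    // KV-cache blocks available
    const size_t cache_block_seq_len   = 128;    // tokens per block
    size_t       session_len           = 65536;  // requested session length

    const size_t max_session_len = cache_max_block_count * cache_block_seq_len;  // 38400 tokens
    if (max_session_len < session_len) {
        printf("`session_len` (%zu) truncated to %zu\n", session_len, max_session_len);
        session_len = max_session_len;
    }
    return 0;
}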
template<typename T> template<typename T>
void LlamaBatch<T>::InitializeSampling() void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
{ {
NvtxScope _("InitSampling"); NvtxScope _("InitSampling");
const int batch_size = state_->active_size; const int batch_size = state_->active_size - g.partial;
if (batch_size == 0) {
return;
}
// Context length at initialization; it stays constant until re-initialization
Copy(context_length_buf_, batch_size, init_context_length_);
Copy(context_length_buf_, batch_size, sequence_lengths_);
// `sequence_lengths_` will be increased by dynamic decode
// note that "sequence length" has different semantics in the decoder and in the output
// - in the decoder it is the number of tokens whose KV cache has already been computed
// - in the output it is the total number of tokens (the last generated token does not have its K/V cache computed yet)
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
sync_check_cuda_error();
Clear(token_ids_buf_, batch_size * session_len_);
invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
sync_check_cuda_error();
// token_ids_buf_[s, b]
// ABCDe ABCDe e
// ABCDEFGHIJk ABCDEFGHIJk
// ABCDEFGHi -> ABCDEFGHi i
// ABCDEFGh ABCDEFGh h
// ABCd ABCd d
invokePadLastTokenIds(token_ids_buf_, init_context_length_, g.max_init_ctx_len, batch_size, stream_);
sync_check_cuda_error();
// `seq_limit_len_` will be compared to `step` instead of `sequence_length`, so the padding length must be accounted for
for (int i = 0; i < batch_size; ++i) {
h_seq_limit_len_[i] = state_->seq_len_limit[i] + (g.max_init_ctx_len - state_->h_context_length[i]);
}
Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
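The host-side sketch below (illustrative only, not the CUDA kernels used by the commit) spells out what the [s, b] diagram above describes: after transposition each sequence occupies a column, the last real token of every shorter sequence is copied to row `max_init_ctx_len - 1`, and `seq_limit_len_` is padded by the same amount so the limit keeps referring to the same logical position.

#include <cstdio>
#include <vector>

int main()
{
    const int batch_size       = 3;
    const int session_len      = 8;
    const int max_init_ctx_len = 8;

    std::vector<int> context_length = {5, 3, 8};
    std::vector<int> seq_len_limit  = {64, 64, 64};

    // token_ids[s * batch_size + b], zero-padded, [S, B] layout
    std::vector<int> token_ids(session_len * batch_size, 0);
    for (int b = 0; b < batch_size; ++b) {
        for (int s = 0; s < context_length[b]; ++s) {
            token_ids[s * batch_size + b] = 100 * (b + 1) + s;  // dummy token ids
        }
    }

    for (int b = 0; b < batch_size; ++b) {
        // copy each sequence's last real token up to row max_init_ctx_len - 1
        const int last = token_ids[(context_length[b] - 1) * batch_size + b];
        token_ids[(max_init_ctx_len - 1) * batch_size + b] = last;

        // pad the per-sequence length limit by the same amount, as in the loop above
        const int limit = seq_len_limit[b] + (max_init_ctx_len - context_length[b]);
        printf("slot %d: padded last token %d, padded seq limit %d\n", b, last, limit);
    }
    return 0;
}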
TensorMap inputs; TensorMap inputs;
for (const auto& [name, h_ptr, d_ptr] : sampling_params_) { for (const auto& [name, h_ptr, d_ptr] : sampling_params_) {
// find an exemplar that matches the param name // find an exemplar that matches the param name
...@@ -828,6 +966,7 @@ void LlamaBatch<T>::InitializeSampling() ...@@ -828,6 +966,7 @@ void LlamaBatch<T>::InitializeSampling()
const int size_in_bytes = ref.sizeBytes(); const int size_in_bytes = ref.sizeBytes();
memset(h_ptr, 0, size_in_bytes * batch_size); memset(h_ptr, 0, size_in_bytes * batch_size);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
FT_CHECK(state_->requests[i] != nullptr);
if (state_->requests[i]->inputs[rank_].isExist(name)) { if (state_->requests[i]->inputs[rank_].isExist(name)) {
Tensor& src = state_->requests[i]->inputs[rank_].at(name); Tensor& src = state_->requests[i]->inputs[rank_].at(name);
FT_CHECK(ref.shape == src.shape); FT_CHECK(ref.shape == src.shape);
...@@ -854,343 +993,6 @@ void LlamaBatch<T>::InitializeSampling() ...@@ -854,343 +993,6 @@ void LlamaBatch<T>::InitializeSampling()
model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_); model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);
} }
template<typename T>
auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
{
NvtxScope _("InitGen");
const int batch_size = state_->active_size;
const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);
Copy(state_->h_context_length, batch_size, context_length_buf_); // also referenced in `SetOutputTensors`
Copy(context_length_buf_, batch_size, sequence_lengths_);
// `sequence_lengths_` will be increased by dynamic decode
// note that in decoder and in output "sequence length" has different semantic
// - in decoder it means length of sequence that has kv cache already computed
// - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
sync_check_cuda_error();
Clear(token_ids_buf_, batch_size * session_len_);
invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
sync_check_cuda_error();
// token_ids_buf_[s, b]
// ABCDe ABCDe e
// ABCDEFGHIJk ABCDEFGHIJk
// ABCDEFGHi -> ABCDEFGHi i
// ABCDEFGh ABCDEFGh h
// ABCd ABCd d
invokePadLastTokenIds(token_ids_buf_, context_length_buf_, max_context_len, batch_size, stream_);
sync_check_cuda_error();
// used for dispatching split-k decoding kernels
const int sum_seq_len =
std::accumulate(state_->h_context_length, state_->h_context_length + batch_size, -batch_size);
const int max_seq_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size) - 1;
// seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted
// for
for (int i = 0; i < batch_size; ++i) {
h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
}
Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
Copy(state_->h_finished, batch_size, finished_buf_);
for (int i = 0; i < batch_size; ++i) {
Tensor& output_ids = state_->requests[i]->outputs[rank_].at("output_ids");
int* req_output_ids_ptr = output_ids.getPtr<int>();
int* req_seqlen_ptr = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
h_request_output_ids_ptrs_[i] = req_output_ids_ptr;
h_request_output_ids_lens_[i] = output_ids.shape.at(2);
h_request_seqlen_ptrs_[i] = req_seqlen_ptr;
FT_CHECK(h_request_output_ids_ptrs_[i]);
FT_CHECK(h_request_output_ids_lens_[i]);
FT_CHECK(h_request_seqlen_ptrs_[i]);
}
Copy(h_request_output_ids_ptrs_, batch_size, request_output_ids_ptrs_);
Copy(h_request_output_ids_lens_, batch_size, request_output_ids_lens_);
Copy(h_request_seqlen_ptrs_, batch_size, request_seqlen_ptrs_);
Copy(state_->h_rope_theta, batch_size, rope_theta_);
// ! range of step_ [1, 2 * session_len]
// consider a sequence with context_len == session_len and another sequence with context_len == 1 and
// request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
const int start_step = max_context_len;
if (rank_ == 0) {
TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len);
if (debug_) {
TM_LOG_INFO("[initGen] slot sequence_id context_len seq_limit_len finished");
for (int i = 0; i < batch_size; ++i) {
TM_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d",
i,
(long)state_->sequences[i]->id,
state_->h_context_length[i],
(int)h_seq_limit_len_[i],
(int)state_->h_finished[i]);
}
}
}
return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
}
template<typename T>
bool LlamaBatch<T>::Generate(GenerationState& g)
{
NvtxScope scope("Generate");
const int batch_size = state_->active_size;
constexpr int kLogInterval = 10;
if (rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
}
const bool is_first_step = (g.step == g.max_init_ctx_len);
std::vector<int> prev;
if (debug_ && rank_ == 0 && is_first_step) {
prev.resize(batch_size);
Copy(token_ids_buf_ + (g.step - 1) * batch_size, batch_size, prev.data());
}
// embeddingLookup(step_ - 1);
model_->embeddingLookup(decoder_input_buf_, //
token_ids_buf_,
batch_size,
g.step - 1);
model_->decoderForward(decoder_output_buf_,
k_block_ptrs_,
v_block_ptrs_,
decoder_input_buf_,
sequence_lengths_,
finished_buf_,
cu_block_counts_,
rope_theta_,
g.step,
0,
g.sum_seq_len,
g.max_seq_len,
batch_size);
model_->postDecodeEmbedding(logits_buf_, //
local_logits_buf_,
decoder_output_buf_,
batch_size);
/// sync for better NVTX visualization, THIS IS NOT NEEDED
// check_cuda_error(cudaStreamSynchronize(stream_));
// stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
// not supported yet.
bool should_stop{};
model_->dynamicDecode(token_ids_buf_,
finished_buf_,
sequence_lengths_,
&should_stop,
state_->curand_state,
&inputs_,
&outputs_,
logits_buf_,
seq_limit_len_,
context_length_buf_,
d_end_ids_buf_,
g.step,
0,
g.max_init_ctx_len,
session_len_ * 2,
batch_size);
if (debug_ && rank_ == 0) {
std::vector<int> curr(batch_size);
Copy(token_ids_buf_ + g.step * batch_size, batch_size, curr.data());
cudaStreamSynchronize(stream_);
if (is_first_step) {
std::stringstream sprev;
for (int k = 0; k < prev.size(); ++k) {
sprev << std::setw(6) << prev[k];
}
TM_LOG_INFO("[ lookup ] step = %d, [%s]", g.step - 1, sprev.str().c_str());
}
std::stringstream scurr;
for (int k = 0; k < curr.size(); ++k) {
scurr << std::setw(6) << curr[k];
}
TM_LOG_INFO("[generate] step = %d, [%s]", g.step - 1, scurr.str().c_str());
}
////////////////////////////////////////////////
/// ! increase the counters
g.step += 1;
g.max_seq_len += 1;
g.sum_seq_len += batch_size;
return !should_stop;
}
template<typename T>
void LlamaBatch<T>::ContextDecode()
{
NvtxScope _("prefill");
const auto batch_size = state_->active_size;
int base = -1;
for (int i = 0; i < batch_size; ++i) {
if (state_->is_swap_in[i]) {
const auto& seq = *state_->sequences[i];
dbg(std::tuple(i, state_->h_context_length[i], seq.cache_len));
if (const int missing = state_->h_context_length[i] - seq.cache_len; missing > 1) {
base = base < 0 ? i : base;
dbg(seq.tokens, seq.cache_len);
Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
// subtract input/context len by 1 to skip last input token (will process with decoder later)
h_input_length_buf_[i] = missing - 1;
}
}
}
if (base < 0) {
// TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
return;
}
const int context_decode_count = batch_size - base;
Copy(state_->h_context_length, batch_size, context_length_buf_);
Copy(state_->h_rope_theta, batch_size, rope_theta_);
Copy(h_input_length_buf_, batch_size, input_length_buf_);
// check_cuda_error(cudaStreamSynchronize(stream_));
// const auto tick = std::chrono::high_resolution_clock::now();
if (rank_ == 0) {
TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
}
// subtract input/context len by 1 to skip last input token (will process with decoder later)
invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);
// find sub-batch offsets
std::vector<int> offsets{base};
std::vector<int> max_context_cnts;
int accum_size = 0;
int accum_input_count = 0;
int max_context_count = 0;
for (int i = base; i < batch_size; ++i) {
int size = accum_size + 1;
int input_count = accum_input_count + h_input_length_buf_[i];
int context_count = std::max(max_context_count, state_->h_context_length[i] - 1);
// we have `cu_seqlens` on q so no padding for input is needed
// kernels are expecting uniform k/v cache length -> `max_context_count * size <= max_context_token_num_`
if (input_count <= max_context_token_num_ && context_count * size <= max_context_token_num_) {
accum_size = size;
accum_input_count = input_count;
max_context_count = context_count;
}
else {
offsets.push_back(i);
max_context_cnts.push_back(max_context_count);
accum_size = 1;
accum_input_count = h_input_length_buf_[i];
max_context_count = state_->h_context_length[i] - 1;
}
}
offsets.push_back(batch_size);
max_context_cnts.push_back(max_context_count);
dbg(offsets, max_context_cnts);
// context decode on sub-batches
for (int k = 0; k < offsets.size() - 1; ++k) {
int first = offsets[k];
int last = offsets[k + 1];
int sub_batch_size = last - first;
T* k_ptr = tmp_k_cache_buf_;
T* v_ptr = tmp_v_cache_buf_;
std::vector<int> decode_indices{};
std::vector<int> decode_lengths{};
int max_input_len{};
auto input_ids = context_decoder_ids_buf_;
TM_LOG_INFO("first = %d, last = %d", first, last);
for (int i = first; i < last; ++i) {
// TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
dbg(i, h_input_length_buf_[i]);
h_tmp_k_ptrs_[i] = k_ptr;
h_tmp_v_ptrs_[i] = v_ptr;
k_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
v_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
decode_indices.push_back(i);
decode_lengths.push_back(h_input_length_buf_[i]);
max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
}
int token_count = input_ids - context_decoder_ids_buf_;
dbg(token_count, max_input_len, max_context_cnts[k]);
Copy(h_tmp_k_ptrs_ + first, sub_batch_size, tmp_k_ptrs_ + first);
Copy(h_tmp_v_ptrs_ + first, sub_batch_size, tmp_v_ptrs_ + first);
if (rank_ == 0) {
TM_LOG_INFO(
"[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
base,
sub_batch_size,
token_count,
max_input_len,
max_context_cnts[k]);
}
dbg(first, last);
dbg(k_block_ptrs_, v_block_ptrs_);
model_->contextDecode(nullptr,
k_block_ptrs_,
v_block_ptrs_,
tmp_k_ptrs_ + first,
tmp_v_ptrs_ + first,
context_decoder_input_buf_,
context_decoder_output_buf_,
context_decoder_ids_buf_,
input_length_buf_ + first,
context_length_buf_ + first,
cu_block_counts_ + first,
rope_theta_ + first,
token_count,
max_input_len,
max_context_cnts[k],
max_context_cnts[k],
sub_batch_size);
// compute logits of inputs if requested
OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
}
invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);
std::fill(h_input_length_buf_ + base, h_input_length_buf_ + batch_size, 0);
// `SequenceManager` needs real-time value of cache length
for (int i = base; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i] - 1; // -1 since we skip last token
}
}
// check_cuda_error(cudaStreamSynchronize(stream_));
// const auto tock = std::chrono::high_resolution_clock::now();
// if (rank_ == 0) {
// TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
// }
}
template<typename T> template<typename T>
void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output, void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output,
const std::vector<int>& indices, const std::vector<int>& indices,
...@@ -1239,21 +1041,25 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_ ...@@ -1239,21 +1041,25 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
} }
template<typename T> template<typename T>
auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal> auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
{ {
NvtxScope scope("Finish"); NvtxScope scope("Finish");
const int batch_size = state_->active_size; const int batch_size = state_->active_size;
// [s,b] -> [b,s] and skip padding in [context_len, max_context_len) if (batch_size - g.partial) {
invokeGatherOutput(state_->output_ids, FT_CHECK(g.step >= 0);
token_ids_buf_,
context_length_buf_, // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
g.max_init_ctx_len, invokeGatherOutput(state_->output_ids,
g.step, token_ids_buf_,
session_len_, init_context_length_,
batch_size, g.max_init_ctx_len,
stream_); g.step,
sync_check_cuda_error(); session_len_,
batch_size - g.partial,
stream_);
sync_check_cuda_error();
}
Copy(state_->output_ids, batch_size * session_len_, h_output_ids_); Copy(state_->output_ids, batch_size * session_len_, h_output_ids_);
Copy(finished_buf_, batch_size, state_->h_finished); Copy(finished_buf_, batch_size, state_->h_finished);
...@@ -1261,15 +1067,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect ...@@ -1261,15 +1067,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
check_cuda_error(cudaStreamSynchronize(stream_)); check_cuda_error(cudaStreamSynchronize(stream_));
// `SequenceManager` needs real-time value of cache length
// ! Must be done before incrementing `h_context_length` because the generated token is NOT kv-cached yet
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i];
}
}
// invariant: context_length = sequence_length + 1, so that h_context_length include all (including the one just // invariant: context_length = sequence_length + 1, so that h_context_length include all (including the one just
// generated) tokens // generated) tokens
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
...@@ -1278,34 +1075,42 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect ...@@ -1278,34 +1075,42 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
{ // set output tokens ids and sequence length { // set output tokens ids and sequence length
int* output_ptr = h_output_ids_; int* output_ptr = h_output_ids_;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size - g.partial; ++i) {
if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) { if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
const int count = state_->h_context_length[i]; auto output_ids = state_->requests[i]->outputs[rank_].getPtr<int>("output_ids");
// TODO: sync history output tokens at when receiving the request and copy only the last token here auto output_len = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]); const int count = state_->h_context_length[i];
*h_request_seqlen_ptrs_[i] = count; // TODO: sync history output tokens at when receiving the request and copy the last token here
std::copy(output_ptr, output_ptr + count, output_ids);
*output_len = count;
} }
output_ptr += session_len_; output_ptr += session_len_;
} }
} }
if (debug_ && rank_ == 0) { if (debug_ && rank_ == 0) {
std::stringstream ss;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")"; // ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
std::vector<int> tokens(state_->h_context_length[i]);
Copy(state_->output_ids + i * session_len_, tokens.size(), tokens.data());
cudaStreamSynchronize(stream_);
std::stringstream ss;
for (const auto& t : tokens) {
ss << " " << t;
}
TM_LOG_INFO("[Finish] slot %d, tokens [%s]", i, ss.str().c_str());
} }
TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
} }
std::vector<Signal> signals; std::vector<Signal> signals;
{ {
NvtxScope _("stream_and_completion_signal"); NvtxScope _("stream_and_completion_signal");
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size - g.partial; ++i) {
if (state_->requests[i]) { if (state_->requests[i]) {
if (state_->h_finished[i]) { if (state_->h_finished[i]) {
// Interrupt finished sequences and move the request handle into the signal closure // Interrupt finished sequences and move the request handle into the signal closure
signals.push_back(Interrupt(i)); signals.push_back(Interrupt(i));
++finished_count; ++g.finished_count;
} }
else if (state_->requests[i]->stream_cb) { else if (state_->requests[i]->stream_cb) {
// Create signals by copying the request handles for non-finished streaming requests // Create signals by copying the request handles for non-finished streaming requests
...@@ -1317,11 +1122,18 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect ...@@ -1317,11 +1122,18 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
} }
} }
} }
if (finished_count) { if (g.finished_count) {
// synchronize for interrupted sequences // synchronize for interrupted sequences
check_cuda_error(cudaStreamSynchronize(stream_)); check_cuda_error(cudaStreamSynchronize(stream_));
} }
} }
if (g.partial) {
const int i = batch_size - 1;
// recover full context length of partial
state_->h_context_length[i] = g.partial_context_length;
}
return signals; return signals;
} }
...@@ -1387,9 +1199,6 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id) ...@@ -1387,9 +1199,6 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
auto& infer_requests = shared_state->infer_requests; auto& infer_requests = shared_state->infer_requests;
auto& stop_requests = shared_state->stop_requests; auto& stop_requests = shared_state->stop_requests;
// sequences that are removed but still counted in state's size
int finished_count = 0;
GenerationState g{}; GenerationState g{};
constexpr int request_interval = 1; constexpr int request_interval = 1;
...@@ -1397,7 +1206,7 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id) ...@@ -1397,7 +1206,7 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
while (1) { while (1) {
if (rank_ == 0) { if (rank_ == 0) {
const int free_slot_count = max_batch_size_ - state_->size + finished_count; const int free_slot_count = max_batch_size_ - state_->size + g.finished_count;
const bool is_empty = (free_slot_count == max_batch_size_); const bool is_empty = (free_slot_count == max_batch_size_);
stop_requests.clear(); stop_requests.clear();
infer_requests.clear(); infer_requests.clear();
...@@ -1430,33 +1239,27 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id) ...@@ -1430,33 +1239,27 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
SendSignals(std::move(signals)); SendSignals(std::move(signals));
auto modified = Initialize(); Initialize(g);
// finished sequences is handled by `Initialize()`
finished_count = 0;
if (state_->active_size) {
ContextDecode();
if (modified) { FT_CHECK(step_length_ == 1);
g = InitializeGeneration();
InitializeSampling();
}
if (state_->active_size) {
for (int i = 0; i < step_length_; ++i) { for (int i = 0; i < step_length_; ++i) {
if (!Generate(g)) { //
break; auto cont = Forward(g, i);
//
if (auto signals = Finish(g); !signals.empty()) {
if (g.finished_count) {
// Finished requests and corresponding output tensors will be released when notified
// wait for all ranks to ensure no rank (except for output thread) will access related
// resources
shared_state->barrier->wait();
}
SendSignals(std::move(signals));
} }
} if (!cont) { // early exit
break;
if (auto signals = Finish(g, finished_count); !signals.empty()) {
if (finished_count) {
// Finished requests and corresponding output tensors will be released when notified
// wait for all ranks to ensure no rank (except for output thread) will access related
// resources
shared_state->barrier->wait();
} }
SendSignals(std::move(signals));
} }
} }
...@@ -1521,6 +1324,235 @@ void LlamaBatch<T>::OutputThreadEntry() ...@@ -1521,6 +1324,235 @@ void LlamaBatch<T>::OutputThreadEntry()
} }
} }
template<typename T>
bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
{
NvtxScope _("Forward");
FT_CHECK(max_context_token_num_ >= max_batch_size_);
const int active_size = state_->active_size;
constexpr int kLogInterval = 10;
if (rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
}
int pf_offset = -1;
std::vector<int*> input_d_ptrs(active_size);
if (iter == 0) { // The first iter may have pre-fill tokens
for (int i = 0; i < active_size; ++i) {
const auto& seq = *state_->sequences[i];
// const int missing = state_->h_context_length[i] - seq.cache_len;
FT_CHECK(seq.input_length >= 1);
h_input_length_buf_[i] = seq.input_length;
input_d_ptrs[i] = state_->output_ids + i * session_len_ + seq.cache_len;
if (seq.input_length > 1 && pf_offset < 0) {
pf_offset = i;
}
}
if (pf_offset < 0) {
pf_offset = active_size;
}
}
else {
for (int i = 0; i < active_size; ++i) {
h_input_length_buf_[i] = 1;
input_d_ptrs[i] = state_->output_ids + i * session_len_ + state_->h_context_length[i] - 1;
}
pf_offset = active_size;
}
// These buffers are only accessed when there are prefill workloads
if (pf_offset != active_size) {
Copy(state_->h_context_length, active_size, context_length_buf_);
Copy(h_input_length_buf_, active_size, input_length_buf_);
}
// Find mini-batch offsets: input length > 1 ? prefill() : decode()
// Constraints on mini-batches
// - `context_decoder_input` and `context_decoder_output` can hold `max_context_token_num_` tokens w/o padding
// - prefill() uses `tmp_k_cache_buf_` and `tmp_v_cache_buf_`; they can hold `max_context_token_num_` tokens
// but each sequence is padded to the maximum context length in the batch
std::vector<int> offsets{0};
std::vector<int> max_context_cnts;
// initialize first mini-batch with decode tokens
int accum_size = pf_offset;
int accum_token_count = pf_offset;
int max_context_count = 0;
for (int i = pf_offset; i < active_size; ++i) {
FT_CHECK(iter == 0);
int size = accum_size + 1;
int input_count = accum_token_count + h_input_length_buf_[i];
int context_count = std::max(max_context_count, state_->h_context_length[i]);
// correct pre-fill batch size for the first batch
int pf_size = offsets.size() == 1 ? size - pf_offset : size;
// we have `cu_seqlens` on q so no padding for input is needed
// prefill kernels are expecting uniform k/v cache length -> `max_context_count * size <=
// max_context_token_num_`
if (input_count <= max_context_token_num_ && context_count * pf_size <= max_context_token_num_) {
accum_size = size;
accum_token_count = input_count;
max_context_count = context_count;
}
else {
offsets.push_back(i);
max_context_cnts.push_back(max_context_count);
accum_size = 1;
accum_token_count = h_input_length_buf_[i];
max_context_count = state_->h_context_length[i];
}
}
offsets.push_back(active_size);
max_context_cnts.push_back(max_context_count);
// forward on mini-batches
for (int p = 0; p < (int)offsets.size() - 1; ++p) {
int first = offsets[p];
int last = offsets[p + 1];
int mini_batch_size = last - first;
T* k_ptr = tmp_k_cache_buf_;
T* v_ptr = tmp_v_cache_buf_;
int max_input_len{};
auto input_ids = context_decoder_ids_buf_;
//
std::vector<int> decode_indices{};
std::vector<int> decode_lengths{};
BatchedCopy batched_copy;
for (int i = first; i < last; ++i) {
input_ids = batched_copy.Add(input_d_ptrs[i], h_input_length_buf_[i], input_ids);
dbg(i, h_input_length_buf_[i]);
// allocate tmp k/v buffer for pre-fill sequences
if (i < pf_offset) {
h_tmp_k_ptrs_[i] = h_tmp_v_ptrs_[i] = nullptr;
}
else {
h_tmp_k_ptrs_[i] = k_ptr;
h_tmp_v_ptrs_[i] = v_ptr;
k_ptr += model_->local_kv_head_num_ * max_context_cnts[p] * model_->size_per_head_;
v_ptr += model_->local_kv_head_num_ * max_context_cnts[p] * model_->size_per_head_;
}
decode_indices.push_back(i);
decode_lengths.push_back(h_input_length_buf_[i]);
max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
}
int token_count = input_ids - context_decoder_ids_buf_;
batched_copy.Submit(stream_);
Copy(h_tmp_k_ptrs_ + first, mini_batch_size, tmp_k_ptrs_ + first);
Copy(h_tmp_v_ptrs_ + first, mini_batch_size, tmp_v_ptrs_ + first);
const int dc_batch_size = p ? 0 : pf_offset;
const int pf_batch_size = mini_batch_size - dc_batch_size;
if (rank_ == 0) {
if (pf_batch_size) {
TM_LOG_INFO("[Forward] [%d, %d), dc_bsz = %d, pf_bsz = %d, n_tok = %d, max_q = %d, max_k = %d",
first,
last,
dc_batch_size,
pf_batch_size,
token_count,
max_input_len,
max_context_cnts[p]);
}
}
model_->forwardUnified(decoder_output_buf_ + first * model_->hidden_units_,
context_decoder_output_buf_, // temp
context_decoder_input_buf_, // temp
(void**)k_block_ptrs_,
(void**)v_block_ptrs_,
context_decoder_ids_buf_, // temp
cu_block_counts_ + first,
rope_theta_ + first,
finished_buf_ + first,
input_length_buf_ + first,
context_length_buf_ + first,
(T**)tmp_k_ptrs_ + first,
(T**)tmp_v_ptrs_ + first,
token_count,
dc_batch_size,
g.step,
g.sum_seq_len,
g.max_seq_len,
pf_batch_size,
max_input_len,
max_context_cnts[p],
max_context_cnts[p]);
if (iter == 0) {
// compute logits of inputs if requested
OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
}
}
std::fill(h_input_length_buf_, h_input_length_buf_ + active_size, 0);
// `SequenceManager` needs real-time value of cache length
for (int i = 0; i < active_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len += state_->sequences[i]->input_length;
}
}
bool should_stop{};
if (active_size > g.partial) {
model_->postDecodeEmbedding(logits_buf_, local_logits_buf_, decoder_output_buf_, active_size - g.partial);
FT_CHECK(g.step >= 0);
// TM_LOG_INFO("dyn decode bsz %d, partial %d", active_size, g.partial);
// stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
// not supported yet.
model_->dynamicDecode(token_ids_buf_,
finished_buf_,
sequence_lengths_,
&should_stop,
state_->curand_state,
&inputs_,
&outputs_,
logits_buf_,
seq_limit_len_,
init_context_length_,
d_end_ids_buf_,
g.step,
0,
g.max_init_ctx_len,
session_len_ * 2,
active_size - g.partial);
}
if (debug_ && rank_ == 0) {
std::vector<int> curr(active_size);
Copy(token_ids_buf_ + g.step * active_size, active_size, curr.data());
cudaStreamSynchronize(stream_);
std::stringstream scurr;
for (int k = 0; k < curr.size(); ++k) {
scurr << std::setw(6) << curr[k];
}
TM_LOG_INFO("[Forward] step = %d, [%s]", g.step - 1, scurr.str().c_str());
}
// check_cuda_error(cudaStreamSynchronize(stream_));
////////////////////////////////////////////////
/// ! increase the counters
g.step += 1;
g.max_seq_len += 1;
g.sum_seq_len += state_->active_size;
// PrintDecodeTokens(token_ids_buf_, g.step, active_size, stream_, "Forward");
return !should_stop;
}
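
// The partition loop in `Forward` above greedily grows each mini-batch until either the flat token budget
// (`max_context_token_num_`) or the padded prefill k/v budget (`max_context * prefill_batch_size`) would be
// exceeded.  Below is a minimal, self-contained sketch of that rule; the function name and the use of
// std::vector are illustrative only and not part of LlamaBatch.
#include <algorithm>
#include <vector>

inline std::vector<int> SketchSplitMiniBatches(const std::vector<int>& input_len,    // per-slot input length
                                               const std::vector<int>& context_len,  // per-slot context length
                                               int                     pf_offset,    // first prefill slot
                                               int                     token_budget) // ~ max_context_token_num_
{
    const int        active_size = (int)input_len.size();
    std::vector<int> offsets{0};
    // the first mini-batch starts with all decode-only slots (1 token each)
    int accum_size = pf_offset, accum_tokens = pf_offset, max_context = 0;
    for (int i = pf_offset; i < active_size; ++i) {
        const int size    = accum_size + 1;
        const int tokens  = accum_tokens + input_len[i];
        const int context = std::max(max_context, context_len[i]);
        // only the first mini-batch carries the decode-only slots, so its prefill size excludes them
        const int pf_size = offsets.size() == 1 ? size - pf_offset : size;
        if (tokens <= token_budget && context * pf_size <= token_budget) {
            accum_size   = size;
            accum_tokens = tokens;
            max_context  = context;
        }
        else {  // budget exceeded -> start a new mini-batch at slot i
            offsets.push_back(i);
            accum_size   = 1;
            accum_tokens = input_len[i];
            max_context  = context_len[i];
        }
    }
    offsets.push_back(active_size);
    return offsets;
}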
template class LlamaBatch<half>; template class LlamaBatch<half>;
template class LlamaBatch<float>; template class LlamaBatch<float>;
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "src/turbomind/models/llama/Request.h" #include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h" #include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
...@@ -28,7 +29,6 @@ struct BatchState { ...@@ -28,7 +29,6 @@ struct BatchState {
float* h_rope_theta; float* h_rope_theta;
std::vector<int> seq_len_limit; std::vector<int> seq_len_limit;
std::vector<int> is_swap_in;
std::vector<const Sequence*> sequences; std::vector<const Sequence*> sequences;
std::vector<std::shared_ptr<Request>> requests; std::vector<std::shared_ptr<Request>> requests;
...@@ -42,6 +42,26 @@ struct BatchState { ...@@ -42,6 +42,26 @@ struct BatchState {
template<typename T> template<typename T>
class LlamaV2; class LlamaV2;
struct GenerationState {
int max_init_ctx_len;
int step;
int sum_seq_len;
int max_seq_len;
int partial;
int partial_context_length;
std::vector<uint64_t> unique_ids;
int max_input_count1;
int max_input_count2;
std::deque<int> min_input_count;
int finished_count;
};
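
// A rough sketch of how the per-step counters in GenerationState advance on each decode iteration,
// mirroring the updates at the end of LlamaBatch::Forward; the free function below is illustrative
// only and not part of the class interface.
inline void SketchAdvanceCounters(GenerationState& g, int active_size)
{
    g.step += 1;                   // one more generated position
    g.max_seq_len += 1;            // the longest sequence grows by one token
    g.sum_seq_len += active_size;  // every active sequence grows by one token
}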
template<typename T> template<typename T>
class LlamaBatch { class LlamaBatch {
public: public:
...@@ -58,35 +78,24 @@ public: ...@@ -58,35 +78,24 @@ public:
void ProcessInferRequests(const Requests& requests); void ProcessInferRequests(const Requests& requests);
[[nodiscard]] bool Initialize(); void AdjustMaxInputCount(GenerationState& g,
const std::vector<const Sequence*>& sequences,
void ContextDecode(); const std::vector<int>& context_length);
struct GenerationState { void Initialize(GenerationState& g);
int max_init_ctx_len;
int step;
int sum_seq_len;
int max_seq_len;
};
void InitializeSampling(); void InitializeSampling(const GenerationState& g);
GenerationState InitializeGeneration(); [[nodiscard]] bool Forward(GenerationState& g, int iter);
[[nodiscard]] bool Generate(GenerationState& g); [[nodiscard]] auto Finish(GenerationState& g) -> std::vector<Signal>;
[[nodiscard]] auto Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>;
[[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false); [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
void void
OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths); OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
explicit LlamaBatch(int max_batch_size, explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model);
int max_context_token_num,
int session_len,
std::unique_ptr<SequenceManager> sequence_manager,
LlamaV2<T>* llama);
~LlamaBatch() ~LlamaBatch()
{ {
...@@ -177,7 +186,7 @@ private: ...@@ -177,7 +186,7 @@ private:
private: private:
const int max_batch_size_; const int max_batch_size_;
const int max_context_token_num_; const int max_context_token_num_;
const int session_len_; int session_len_;
const int rank_; const int rank_;
const bool debug_; const bool debug_;
const int step_length_; const int step_length_;
...@@ -201,6 +210,7 @@ private: ...@@ -201,6 +210,7 @@ private:
// lengths // lengths
int* input_length_buf_{}; // input + cache missed length int* input_length_buf_{}; // input + cache missed length
int* context_length_buf_{}; // history length + input_length int* context_length_buf_{}; // history length + input_length
int* init_context_length_{};
// temp buffers used for block->linear kv-cache conversion // temp buffers used for block->linear kv-cache conversion
T* tmp_k_cache_buf_{}; T* tmp_k_cache_buf_{};
T* tmp_v_cache_buf_{}; T* tmp_v_cache_buf_{};
...@@ -228,13 +238,6 @@ private: ...@@ -228,13 +238,6 @@ private:
int* h_end_ids_buf_{}; int* h_end_ids_buf_{};
int* d_end_ids_buf_{}; int* d_end_ids_buf_{};
int** request_output_ids_ptrs_{};
int* request_output_ids_lens_{};
int** request_seqlen_ptrs_{};
int** h_request_output_ids_ptrs_{};
int* h_request_output_ids_lens_{};
int** h_request_seqlen_ptrs_{};
// pinned buffers // pinned buffers
int* h_input_ids_buf_{}; int* h_input_ids_buf_{};
int* h_input_length_buf_{}; int* h_input_length_buf_{};
...@@ -293,6 +296,10 @@ private: ...@@ -293,6 +296,10 @@ private:
bool output_stop_token_{false}; bool output_stop_token_{false};
int* h_output_ids_{}; int* h_output_ids_{};
const int num_tokens_per_iter_;
const int extra_tokens_per_iter_;
const int max_prefill_iters_;
}; };
} // namespace turbomind } // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "src/turbomind/models/llama/LlamaCacheManager.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
namespace turbomind {
LlamaCacheManager::~LlamaCacheManager()
{
for (auto& p : device_mem_) {
allocator_->free(&p, false);
}
}
void* LlamaCacheManager::allocate(bool is_preallocate)
{
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][allocate]");
}
void* mem_ptr{};
if (!device_free_.empty()) {
mem_ptr = device_free_.front();
device_free_.pop();
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][allocate] free = %d", (int)device_free_.size());
}
}
else if (entry_count_ < max_entry_count_) {
const auto alloc_count = std::min(chunk_size_, max_entry_count_ - entry_count_);
const size_t entry_byte_size = 2 * cache_byte_size_; // 2 for k,v
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][allocate] malloc %d", (int)alloc_count);
}
const auto chunk_ptr = allocator_->malloc(alloc_count * entry_byte_size, false);
FT_CHECK(chunk_ptr);
device_mem_.push_back(chunk_ptr);
entry_count_ += alloc_count;
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][allocate] count = %d", entry_count_);
}
for (int i = 0; i < alloc_count; ++i) {
device_free_.push((uint8_t*)chunk_ptr + entry_byte_size * i);
}
if (!is_preallocate) {
mem_ptr = device_free_.front();
device_free_.pop();
}
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][allocate] free = %d", (int)device_free_.size());
}
}
else {
mem_ptr = evict();
FT_CHECK_WITH_INFO(mem_ptr, "No enough cache entries.");
}
return mem_ptr;
}
auto LlamaCacheManager::create(uint64_t id, cudaStream_t stream) -> Sequence
{
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][create] %ld", (long)id);
}
for (const auto& e : device_cache_) {
if (e.id == id) {
if (rank_ == 0) {
TM_LOG_WARNING("[LlamaCacheManager][create] Removing conflicting id %ld", (long)id);
}
erase(id);
}
}
const auto mem_ptr = (uint8_t*)allocate(false);
check_cuda_error(cudaMemsetAsync(mem_ptr, 0, cache_byte_size_ * 2, stream));
device_cache_.push_back({
id,
max_seq_len_,
{},
0,
mem_ptr,
mem_ptr + cache_byte_size_,
{},
static_cast<uint64_t>(-1),
});
return device_cache_.back();
}
auto LlamaCacheManager::getEntryOrThrow(uint64_t id) -> std::vector<Sequence>::iterator
{
auto pred = [&](const Sequence& s) { return s.id == id; };
auto it = std::find_if(device_cache_.begin(), device_cache_.end(), pred);
if (it == device_cache_.end()) {
TM_LOG_ERROR("[LlamaCacheManager] %ld not found.\n", (long)id);
FT_CHECK(0);
}
return it;
}
auto LlamaCacheManager::fetch(uint64_t id, cudaStream_t stream) -> Sequence
{
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][fetch] %ld", (long)id);
}
auto entry = getEntryOrThrow(id);
if (entry->k_cache == nullptr) {
FT_CHECK(entry->cache_len == 0);
const auto mem_ptr = allocate(false);
check_cuda_error(cudaMemsetAsync(mem_ptr, 0, cache_byte_size_ * 2, stream));
entry->k_cache = mem_ptr;
entry->v_cache = (uint8_t*)entry->k_cache + cache_byte_size_;
}
entry->timestamp = static_cast<uint64_t>(-1);
return *entry;
}
void LlamaCacheManager::update(const Sequence& seq, cudaStream_t stream)
{
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][update] %ld", (long)seq.id);
}
auto entry = getEntryOrThrow(seq.id);
entry->timestamp = ++timestamp_;
entry->token_ids = seq.token_ids;
entry->cache_len = seq.cache_len;
FT_CHECK(seq.k_cache == entry->k_cache && seq.v_cache == entry->v_cache);
}
void LlamaCacheManager::erase(uint64_t id)
{
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][erase] %ld", (long)id);
}
auto entry = getEntryOrThrow(id);
if (entry->k_cache) {
device_free_.push(entry->k_cache);
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][erase] free = %d", (int)device_free_.size());
}
}
device_cache_.erase(entry);
}
void* LlamaCacheManager::evict()
{
FT_CHECK(!device_cache_.empty());
auto it = std::min_element(device_cache_.begin(), device_cache_.end(), [](const auto& a, const auto& b) {
return a.timestamp < b.timestamp;
});
if (it->timestamp == static_cast<uint64_t>(-1)) {
return nullptr;
}
if (rank_ == 0) {
TM_LOG_INFO("[LlamaCacheManager][evict] %ld", (long)it->id);
}
FT_CHECK(it->k_cache);
auto mem_ptr = it->k_cache;
it->k_cache = it->v_cache = nullptr;
it->cache_len = 0;
it->timestamp = static_cast<uint64_t>(-1);
return mem_ptr;
}
bool LlamaCacheManager::contains(uint64_t id) const noexcept
{
auto pred = [&](const Sequence& s) { return s.id == id; };
auto it = std::find_if(device_cache_.begin(), device_cache_.end(), pred);
return it != device_cache_.end();
}
} // namespace turbomind
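
// LlamaCacheManager::evict() above implements a plain LRU policy: the entry with the smallest timestamp
// is reclaimed, while in-flight entries are pinned by setting their timestamp to uint64_t(-1).  A minimal
// standalone sketch of the same selection rule; the types and names below are illustrative only.
#include <algorithm>
#include <cstdint>
#include <vector>

struct SketchEntry {
    uint64_t id;
    uint64_t timestamp;  // uint64_t(-1) means "in use, do not evict"
};

inline const SketchEntry* SketchPickVictim(const std::vector<SketchEntry>& entries)
{
    if (entries.empty()) {
        return nullptr;
    }
    auto it = std::min_element(entries.begin(), entries.end(), [](const auto& a, const auto& b) {
        return a.timestamp < b.timestamp;
    });
    // all entries pinned -> nothing can be evicted
    return it->timestamp == static_cast<uint64_t>(-1) ? nullptr : &*it;
}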
// Copyright (c) OpenMMLab. All rights reserved.
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/logger.h"
#include <cstdint>
#include <cuda_runtime.h>
#include <queue>
#include <unordered_map>
#include <vector>
namespace turbomind {
// k-cache layout [L, H, D/x, S[s:], x]
// v-cache layout [L, H, S[s:], D/x, x]
class LlamaCacheManager {
public:
LlamaCacheManager(size_t layer_num,
size_t head_num,
size_t size_per_head,
size_t max_seq_len,
size_t elem_bits,
size_t max_entry_count,
size_t chunk_size,
int rank,
IAllocator* allocator):
layer_num_(layer_num),
head_num_(head_num),
size_per_head_(size_per_head),
max_seq_len_(max_seq_len),
elem_bits_(elem_bits),
cache_byte_size_(layer_num_ * head_num_ * max_seq_len_ * size_per_head_ * elem_bits_ / 8),
max_entry_count_(max_entry_count),
chunk_size_(chunk_size),
rank_(rank),
allocator_(allocator)
{
if (rank == 0) {
TM_LOG_INFO("[LlamaCacheManager] max_entry_count = %d", (int)max_entry_count_);
TM_LOG_INFO("[LlamaCacheManager] chunk_size = %d", (int)chunk_size_);
}
allocate(true);
}
~LlamaCacheManager();
struct Sequence {
// header
uint64_t id;
size_t max_seq_len;
// payloads
std::vector<int> token_ids; // all token ids
size_t cache_len; // cache_len == 0 -> cache miss
void* k_cache;
void* v_cache;
std::vector<uint8_t> random_state_; // states for RNGs
// for LRU policy
uint64_t timestamp;
};
Sequence create(uint64_t id, cudaStream_t stream);
Sequence fetch(uint64_t id, cudaStream_t stream);
void update(const Sequence& seq, cudaStream_t stream);
void erase(uint64_t id);
bool contains(uint64_t id) const noexcept;
private:
std::vector<Sequence>::iterator getEntryOrThrow(uint64_t id);
void* allocate(bool is_preallocate);
void* evict();
private:
const size_t layer_num_{};
const size_t head_num_{};
const size_t size_per_head_{};
const size_t max_seq_len_{};
const size_t elem_bits_{};
const size_t cache_byte_size_{};
const size_t max_entry_count_{};
const size_t chunk_size_{};
const int rank_{};
IAllocator* allocator_{};
std::queue<void*> device_free_;
std::vector<void*> device_mem_;
int entry_count_{};
uint64_t timestamp_{};
std::vector<Sequence> device_cache_;
};
} // namespace turbomind
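
// Per-sequence cache size follows cache_byte_size_ above:
//   layer_num * head_num * max_seq_len * size_per_head * elem_bits / 8,
// and a full entry stores both K and V (the factor 2 in allocate()).  A worked example with hypothetical
// 7B-like settings; the numbers below are illustrative and not taken from any particular model config.
#include <cstdio>

int main()
{
    const size_t layer_num = 32, head_num = 32, max_seq_len = 2048, size_per_head = 128;
    const size_t elem_bits       = 16;  // fp16 cache
    const size_t cache_byte_size = layer_num * head_num * max_seq_len * size_per_head * elem_bits / 8;
    const size_t entry_byte_size = 2 * cache_byte_size;  // K and V
    std::printf("k-or-v cache per sequence: %zu MiB\n", cache_byte_size >> 20);  // 512 MiB
    std::printf("full entry (K+V):          %zu MiB\n", entry_byte_size >> 20);  // 1024 MiB
    return 0;
}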
/*
* Copyright (c) OpenMMLab. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptContextDecoder.cc
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/kernels/bert_preprocess_kernels.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/debug_utils.h"
namespace turbomind {
template<typename T>
void LlamaContextDecoder<T>::allocateBuffer()
{
FT_CHECK(false);
}
template<typename T>
void LlamaContextDecoder<T>::allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * batch_size * max_q_len * max_kv_len, false);
padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * max_q_len, false);
cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false);
is_allocate_buffer_ = true;
}
template<typename T>
void LlamaContextDecoder<T>::freeBuffer()
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (is_allocate_buffer_) {
allocator_->free((void**)&padding_offset_);
allocator_->free((void**)&cu_seqlens_);
allocator_->free((void**)&attention_mask_);
allocator_->free((void**)&h_pinned_token_num_ptr_, true);
is_allocate_buffer_ = false;
}
}
template<typename T>
void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
bool use_fmha,
int cache_block_seq_len,
int quant_policy)
{
h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
kv_head_num,
size_per_head_,
attn_params,
tensor_para_,
stream_,
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
use_fmha,
cache_block_seq_len,
quant_policy);
silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
size_per_head_,
inter_size_,
tensor_para_,
stream_,
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_);
}
template<typename T>
void LlamaContextDecoder<T>::forwardSelfAttn(const Session& sess,
T* attn_io,
std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
int layer,
bool is_final)
{
// TM_LOG_ERROR(__PRETTY_FUNCTION__);
TensorMap self_attention_input_tensors{
{"input_query", Tensor{MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
{"attention_mask",
{MEMORY_GPU, data_type_, {sess.batch_size, 1, sess.max_query_len, sess.max_key_len}, attention_mask_}},
{"layer_id", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &layer}},
{"is_final_layer", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_final}},
{"padding_offset", {MEMORY_GPU, TYPE_INT32, {sess.token_num}, padding_offset_}},
{"cu_seqlens", {MEMORY_GPU, TYPE_INT32, {sess.batch_size + 1}, cu_seqlens_}},
{"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
{"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
{"cu_block_counts", input_tensors->at("cu_block_counts")},
{"rope_theta", input_tensors->at("rope_theta")},
{"max_seq_len", input_tensors->at("max_seq_len")}};
TensorMap self_attention_output_tensors{
{"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
{"key_cache", output_tensors->at("key_cache")},
{"value_cache", output_tensors->at("value_cache")},
{"tmp_k", output_tensors->at("tmp_k")},
{"tmp_v", output_tensors->at("tmp_v")}};
context_attention_layer_->forward(&self_attention_output_tensors, //
&self_attention_input_tensors,
&sess.weights->at(layer)->self_attn_weights);
}
template<typename T>
LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
const LlamaAttentionParams& attn_params,
float rmsnorm_eps,
NcclParam tensor_para,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool use_fmha,
int cache_block_seq_len,
int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num),
size_per_head_(size_per_head),
inter_size_(inter_size),
hidden_units_(head_num * size_per_head),
num_layer_(num_layer),
rmsnorm_eps_(rmsnorm_eps),
tensor_para_(tensor_para),
data_type_(getTensorType<T>())
{
initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
}
template<typename T>
LlamaContextDecoder<T>::~LlamaContextDecoder()
{
delete context_attention_layer_;
delete silu_ffn_layer_;
freeBuffer();
}
template<typename T>
void LlamaContextDecoder<T>::forward(std::vector<Tensor>* output_tensors,
const std::vector<Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
{
FT_CHECK(false);
}
template<typename T>
void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
{
/**
* input tensors:
* \param decoder_input [num_token, hidden_units], float
* \param input_lengths [batch_size], int
* \param history_lengths [batch_size], int
* \param context_lengths [batch_size], int
* \param output_norm_weight [hidden_dims], float
* \param max_q_len [1], int on cpu
* \param max_kv_len [1], int on cpu
* \param max_seq_len [1], int on cpu
*
* output tensors:
* \param decoder_output [num_token, hidden_units],
* \param key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x]
* \param value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head]
* \param last_token_hidden_units [batch_size, hidden_units]
*/
Session sess{};
sess.token_num = input_tensors->at("decoder_input").shape[0];
sess.batch_size = input_tensors->at("input_lengths").shape[0];
sess.max_query_len = input_tensors->at("max_q_len").getVal<int>();
sess.max_key_len = input_tensors->at("max_kv_len").getVal<int>();
sess.weights = decoder_layer_weights;
sess.input_length = input_tensors->at("input_lengths").getPtr<int>();
sess.context_length = input_tensors->at("context_lengths").getPtr<int>();
T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len);
// dbg(padding_offset_);
FT_CHECK(padding_offset_);
size_t tmp_token_num{};
invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_,
&tmp_token_num, // updated token num
padding_offset_,
cu_seqlens_,
input_tensors->at("input_lengths").getPtr<int>(),
sess.batch_size,
sess.max_query_len,
stream_);
sync_check_cuda_error();
dbg(tmp_token_num, sess.token_num);
FT_CHECK(tmp_token_num == sess.token_num);
invokeCreateCausalMasks(attention_mask_,
sess.input_length,
sess.context_length,
sess.max_query_len,
sess.max_key_len,
sess.batch_size,
stream_);
sync_check_cuda_error();
// Compare(
// decoder_input_output, sess.token_num * hidden_units_, Concat("context_decoder_input", 0), kCmpRead, stream_);
/////////////////////////////////////////////
/// RMSNorm
invokeRootMeanSquareNorm(decoder_output,
decoder_input_output,
decoder_layer_weights->at(0)->self_attn_norm_weights,
rmsnorm_eps_,
sess.token_num,
hidden_units_,
stream_);
sync_check_cuda_error();
for (size_t layer = 0; layer < num_layer_; ++layer) {
/////////////////////////////////////////////
/// self-attention
forwardSelfAttn(sess, decoder_output, output_tensors, input_tensors, layer, false);
invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
decoder_layer_weights->at(layer)->ffn_norm_weights,
rmsnorm_eps_,
sess.token_num,
hidden_units_,
stream_);
sync_check_cuda_error();
////////////////////////////////////////////
/// feed-forward network
TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
TensorMap ffn_outputs{
{"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);
auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input_output, //
decoder_output,
decoder_layer_weights->at(layer)->ffn_weights.output.bias,
scale_weight,
rmsnorm_eps_,
sess.token_num,
hidden_units_,
stream_);
sync_check_cuda_error();
}
if (is_free_buffer_after_forward_) {
freeBuffer();
}
}
template class LlamaContextDecoder<float>;
template class LlamaContextDecoder<half>;
} // namespace turbomind
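
// The context decoder above builds `cu_seqlens_` (exclusive prefix sums of the per-sequence input lengths)
// and `padding_offset_` (for each valid token, how many padding slots precede it in the [batch, max_q_len]
// layout) on the GPU.  A CPU sketch of the same bookkeeping, assuming the usual FasterTransformer-style
// semantics; names below are illustrative only.
#include <vector>

struct SketchPacking {
    std::vector<int> cu_seqlens;      // [batch_size + 1]
    std::vector<int> padding_offset;  // [token_num]
};

inline SketchPacking SketchGetPaddingOffsetAndCuSeqLens(const std::vector<int>& input_lengths, int max_q_len)
{
    SketchPacking p;
    p.cu_seqlens.push_back(0);
    int total = 0, padding = 0;
    for (size_t b = 0; b < input_lengths.size(); ++b) {
        const int len = input_lengths[b];
        for (int t = 0; t < len; ++t) {
            p.padding_offset.push_back(padding);  // tokens of sequence b skip `padding` pad slots
        }
        total += len;
        padding += max_q_len - len;  // pad slots introduced by sequence b
        p.cu_seqlens.push_back(total);
    }
    return p;
}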
/*
* Copyright (c) OpenMMLab. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptContextDecoder.h
#pragma once
#include "src/turbomind/layers/BaseLayer.h"
#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/custom_ar_comm.h"
#include "src/turbomind/utils/nccl_utils.h"
namespace turbomind {
template<typename T>
class LlamaContextDecoder: public BaseLayer {
protected:
void allocateBuffer() override;
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
void freeBuffer() override;
void initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
bool use_fmha,
int cache_block_seq_len,
int quant_policy);
size_t head_num_;
size_t size_per_head_;
size_t inter_size_;
size_t num_layer_;
size_t hidden_units_;
float rmsnorm_eps_;
NcclParam tensor_para_;
T* attention_mask_{};
int* padding_offset_{};
int* cu_seqlens_{}; // cu for cumulative
size_t* h_pinned_token_num_ptr_{};
LlamaContextAttentionLayer<T>* context_attention_layer_{};
LlamaFfnLayer<T>* silu_ffn_layer_{};
const DataType data_type_;
struct Session {
size_t batch_size;
size_t token_num;
size_t max_query_len;
size_t max_key_len;
int* input_length{};
int* context_length{};
const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
};
void forwardSelfAttn(const Session& sess,
T* attn_io,
std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
int layer,
bool is_final);
public:
LlamaContextDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
const LlamaAttentionParams& attn_params,
float rmsnorm_eps,
NcclParam tensor_para,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool use_fmha,
int cache_block_seq_len,
int quant_policy);
~LlamaContextDecoder() override;
virtual void forward(std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
virtual void forward(std::vector<Tensor>* output_tensors,
const std::vector<Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
};
} // namespace turbomind
/*
* Copyright (c) OpenMMLab. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022, SK Telecom Authored by A. Dialog
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoder.cc
#include "src/turbomind/models/llama/LlamaDecoder.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"
namespace turbomind {
template<typename T>
LlamaDecoder<T>::LlamaDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
const LlamaAttentionParams& attn_params,
float rmsnorm_eps,
NcclParam tensor_para,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int cache_block_seq_len,
int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num),
size_per_head_(size_per_head),
inter_size_(inter_size),
num_layer_(num_layer),
hidden_units_(head_num * size_per_head),
rmsnorm_eps_(rmsnorm_eps),
tensor_para_(tensor_para),
data_type_(getTensorType<T>())
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
initialize(attn_params, kv_head_num, cache_block_seq_len, quant_policy);
}
template<typename T>
LlamaDecoder<T>::~LlamaDecoder()
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
delete self_attention_layer_;
delete silu_ffn_layer_;
}
template<typename T>
void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
int cache_block_seq_len,
int quant_policy)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_,
kv_head_num,
size_per_head_,
attn_params,
tensor_para_,
stream_,
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
cache_block_seq_len,
quant_policy);
silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
size_per_head_,
inter_size_,
tensor_para_,
stream_,
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_);
}
template<typename T>
void LlamaDecoder<T>::allocateBuffer()
{
FT_CHECK(false);
}
template<typename T>
void LlamaDecoder<T>::allocateBuffer(size_t batch_size)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
is_allocate_buffer_ = true;
}
template<typename T>
void LlamaDecoder<T>::freeBuffer()
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (is_allocate_buffer_) {
is_allocate_buffer_ = false;
}
}
template<typename T>
void LlamaDecoder<T>::forwardSelfAttn(const LlamaDecoder::Session& sess,
T* attn_io,
const std::unordered_map<std::string, Tensor>* input_tensors,
size_t layer)
{
NvtxScope scope("self_attn");
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TensorMap self_attention_input_tensors(*input_tensors);
self_attention_input_tensors.insert("input_query",
{MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io});
const int layer_id = layer;
self_attention_input_tensors.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id});
auto& k_cache = *sess.k_cache;
auto& v_cache = *sess.v_cache;
TensorMap self_attention_output_tensors{
{"attention_output", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io}},
{"key_cache", k_cache},
{"value_cache", v_cache},
};
self_attention_layer_->forward(&self_attention_output_tensors, //
&self_attention_input_tensors,
&sess.weights->at(layer)->self_attn_weights);
}
template<typename T>
void LlamaDecoder<T>::forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer)
{
TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, ffn_io}}};
TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, ffn_io}}};
silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &sess.weights->at(layer)->ffn_weights);
}
template<typename T>
void LlamaDecoder<T>::forward(std::vector<Tensor>* output_tensors,
const std::vector<Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
{
FT_CHECK(false);
}
template<typename T>
void LlamaDecoder<T>::forward(std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
/**
* input_tensors:
* \param decoder_input [batch_size, hidden_dims]
* \param sequence_lengths [batch_size] int
* \param output_norm_weight [hidden_dims]
* \param step [1] on cpu
* \param ite [1] on cpu
* \param finished [batch_size] bool
* \param total_padding_tokens [batch_size], int
* \param max_seq_len [1] on cpu
* \param masked_tokens [batch_size, memory_len] bool (optional), NOT USED YET
*
* output_tensors:
* \param decoder_output [batch_size, hidden_dimension]
* \param key_cache [batch_size] uint64_t
* \param value_cache [batch_size] uint64_t
*/
// for the shape of key cache, refer to decoder_masked_multihead_attention_template.hpp
NvtxScope forward_scope("decoder_forward");
Session sess{};
sess.batch_size = input_tensors->at("decoder_input").shape[0];
sess.weights = decoder_layer_weights;
allocateBuffer(sess.batch_size);
sess.k_cache = &output_tensors->at("key_cache");
sess.v_cache = &output_tensors->at("value_cache");
T* decoder_input = input_tensors->at("decoder_input").getPtr<T>();
T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
int step = input_tensors->at("step").getVal<int>();
// Compare(decoder_input, sess.batch_size * hidden_units_, Concat("decoder_input", 0, step), kCmpRead, stream_);
////////////////////////////////////////////
/// RMSNorm
{
NvtxScope rms_norm_scope("rms_norm_0");
invokeRootMeanSquareNorm(decoder_output,
decoder_input,
decoder_layer_weights->at(0)->self_attn_norm_weights,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
}
for (size_t layer = 0; layer < num_layer_; ++layer) {
NvtxScope layer_scope("decode_layer");
// output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_)
forwardSelfAttn(sess, decoder_output, input_tensors, layer);
{
NvtxScope rms_norm_scope("rms_norm_1");
invokeFusedAddBiasResidualRMSNorm(decoder_input,
decoder_output,
decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
decoder_layer_weights->at(layer)->ffn_norm_weights,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
}
// decoder_layer_output_ = ffn(decoder_normed_input_)
forwardFfn(sess, decoder_output, layer);
{
NvtxScope rms_norm_scope("rms_norm_2");
auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
input_tensors->at("output_norm_weight").getPtr<T>();
invokeFusedAddBiasResidualRMSNorm(decoder_input, //
decoder_output,
decoder_layer_weights->at(layer)->ffn_weights.output.bias,
scale_weight,
rmsnorm_eps_,
sess.batch_size,
hidden_units_,
stream_);
sync_check_cuda_error();
}
}
if (is_free_buffer_after_forward_) {
freeBuffer();
}
}
template class LlamaDecoder<half>;
template class LlamaDecoder<float>;
} // namespace turbomind
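
// invokeRootMeanSquareNorm / invokeFusedAddBiasResidualRMSNorm above apply the usual RMSNorm,
// y_i = w_i * x_i / sqrt(mean(x^2) + eps).  The scalar reference below sketches the normalization math
// only, not the fused residual/bias handling of the actual kernels; it is illustrative, not the kernel.
#include <cmath>
#include <vector>

inline std::vector<float> SketchRmsNorm(const std::vector<float>& x, const std::vector<float>& w, float eps)
{
    double sum_sq = 0.0;
    for (float v : x) {
        sum_sq += (double)v * v;
    }
    const float        inv_rms = 1.f / std::sqrt((float)(sum_sq / x.size()) + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = w[i] * x[i] * inv_rms;
    }
    return y;
}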
/*
* Copyright (c) OpenMMLab. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022, SK Telecom Authored by A. Dialog
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoder.h
#include "src/turbomind/layers/BaseLayer.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/custom_ar_comm.h"
#include "src/turbomind/utils/nccl_utils.h"
namespace turbomind {
template<typename T>
class LlamaDecoder: public BaseLayer {
protected:
void allocateBuffer() override; // deprecated
void allocateBuffer(size_t batch_size);
void freeBuffer() override;
void
initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int cache_block_seq_len, int quant_policy);
size_t head_num_;
size_t size_per_head_;
size_t inter_size_;
size_t num_layer_;
size_t hidden_units_;
float rmsnorm_eps_;
NcclParam tensor_para_;
LlamaDecoderSelfAttentionLayer<T>* self_attention_layer_{};
LlamaFfnLayer<T>* silu_ffn_layer_{};
const DataType data_type_;
struct Session {
size_t batch_size;
Tensor* k_cache;
Tensor* v_cache;
const std::vector<LlamaDecoderLayerWeight<T>*>* weights;
};
void forwardSelfAttn(const Session& sess,
T* attn_io,
const std::unordered_map<std::string, Tensor>* input_tensors,
size_t layer);
void forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer);
public:
LlamaDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
const LlamaAttentionParams& attn_params,
float rmsnorm_eps,
NcclParam tensor_para,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int cache_block_seq_len,
int quant_policy);
~LlamaDecoder() override;
virtual void forward(std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
virtual void forward(std::vector<Tensor>* output_tensors,
const std::vector<Tensor>* input_tensors,
const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights);
};
} // namespace turbomind
/*
* Copyright (c) OpenMMLab. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.cc
#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
#include "src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/nvtx_utils.h"
#include <string>
// #include <glog/logging.h>
namespace turbomind {
template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
qkv_buf_ = reinterpret_cast<T*>(
allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * local_q_kv_head_num * size_per_head_, false));
context_buf_ =
reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
workspace_ = (float*)allocator_->reMalloc(
workspace_, sizeof(float) * batch_size * local_head_num_ * kMaxSplitK * (size_per_head_ + 2), false);
is_allocate_buffer_ = true;
}
template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::freeBuffer()
{
if (is_allocate_buffer_) {
allocator_->free((void**)(&qkv_buf_));
allocator_->free((void**)(&context_buf_));
is_allocate_buffer_ = false;
}
}
template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* output_tensors,
const TensorMap* input_tensors,
const LlamaAttentionWeight<T>* weights)
{
/**
* input tensors:
* \param input_query [batch_size, hidden_units],
* \param sequence_lengths [batch_size]
* \param step [1] on cpu
* \param finished [batch_size]
* \param total_padding_tokens [batch_size]
* \param layer_id [1], int on cpu
* \param max_seq_len [1] on cpu
* \param masked_tokens [batch_size, memory_len], (optional), NOT USED YET
* \param cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional)
*
* output tensors:
* \param attention_output [batch_size, hidden_units],
* \param key_cache [batch, local_head_num, memory_max_len, size_per_head]
* \param value_cache [batch, local_head_num, memory_max_len, size_per_head]
*/
const T* input_query_data = input_tensors->getPtr<T>("input_query");
const int* sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
const bool* finished_data = input_tensors->getPtr<bool>("finished");
const int sum_seq_len = input_tensors->getVal<int>("sum_seq_len");
const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
T* hidden_features_data = output_tensors->getPtr<T>("attention_output");
T** key_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
T** value_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
int* cu_block_counts = input_tensors->at("cu_block_counts").getPtr<int>();
const int layer_id = input_tensors->getVal<int>("layer_id");
const int step = input_tensors->getVal<int>("step");
// const int step_1 = step - 1;
const int batch_size = input_tensors->at("input_query").shape[0];
const float* rope_theta = input_tensors->getPtr<const float>("rope_theta", nullptr);
allocateBuffer(batch_size);
// for (int i = 0; i < batch_size; ++i) {
// if (gSequenceIds(i) == 1) {
// Compare((T*)input_query_data + hidden_units_ * i,
// hidden_units_,
// Concat("query", gSequenceIds(i), seqlens[i], layer_id),
// compare_mode,
// stream_);
// }
// }
{
NvtxScope scope("qkv_gemm");
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
}
// if (layer_id == 0) {
// Compare(qkv_buf_, batch_size * 3 * hidden_units_, Concat("qkv_buf", step, layer_id), kCmpRead, stream_);
// }
const auto layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * size_per_head_;
// const int memory_len = max_seq_len;
DecoderMultiHeadAttentionParams<T> params{};
params.out = context_buf_;
params.q = qkv_buf_;
params.k = params.q + local_head_num_ * size_per_head_;
params.v = params.k + local_kv_head_num_ * size_per_head_;
params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;
params.q_bias = weights->qkv.bias;
params.k_bias = params.q_bias + local_head_num_ * size_per_head_;
params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_;
params.batch_size = batch_size;
params.cu_block_cnts = cu_block_counts;
params.k_cache_block_ptrs = (void**)key_cache_ptrs;
params.v_cache_block_ptrs = (void**)value_cache_ptrs;
params.kv_cache_block_size = kv_cache_block_len_;
params.finished = finished_data;
params.per_sample_length = sequence_lengths_data;
params.rope_theta = rope_theta;
params.layer_offset = layer_offset;
params.num_heads = local_head_num_;
params.num_kv_heads = local_kv_head_num_;
params.size_per_head = size_per_head_;
params.inv_sqrt_dh = 1.f / std::sqrt((float)params.size_per_head);
params.rotary_embedding_dim = size_per_head_;
params.rotary_embedding_base = params_.rotary_embedding_base;
params.max_position_embeddings = params_.max_position_embeddings;
// params.use_dynamic_ntk = params_.use_dynamic_ntk;
params.use_logn_attn = params_.use_logn_attn;
params.partial_O = workspace_;
params.partial_M = params.partial_O + batch_size * local_head_num_ * kMaxSplitK * size_per_head_;
params.partial_L = params.partial_M + batch_size * local_head_num_ * kMaxSplitK;
// avg_batch_size = sum_seq_len / max_seq_len
// max_split_k = kMaxSplitK / avg_batch_size
// max_split_k' = min(max_split_k, max_seq_len / kSliceLen)
const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
FT_CHECK(avg_batch_size >= 1.f);
const int max_split_k = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
// if (layer_id == 0) {
// TM_LOG_INFO("avg_batch_size = %.1f, max_split_k = %d", avg_batch_size, max_split_k);
// }
params.max_split_k = max_split_k;
params.max_seq_len = max_seq_len;
params.arch = arch_;
params.stream = stream_;
params.quant_policy = quant_policy_;
std::copy(weights->past_kv_scale.begin(), weights->past_kv_scale.end(), std::begin(params.kv_quant_params));
{
NvtxScope scope("decoder_multihead_attention");
DispatchDecoderMultiheadAttention<T>(params);
}
// for (int i = 0; i < batch_size; ++i) {
// if (gSequenceIds(i) == 1) {
// Compare((T*)context_buf_ + hidden_units_ * i,
// hidden_units_,
// Concat("context_buf", gSequenceIds(i), seqlens[i], layer_id),
// compare_mode,
// stream_);
// }
// }
// if (layer_id == 0) {
// Compare(context_buf_, batch_size * hidden_units_, Concat("context_buf", step, layer_id), kCmpRead, stream_);
// }
{
NvtxScope scope("o_gemm");
linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);
}
if (tensor_para_.world_size_ > 1) {
NcclGuard nccl_guard(tensor_para_, stream_);
ftNcclAllReduceSum(
hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
sync_check_cuda_error();
}
if (is_free_buffer_after_forward_) {
freeBuffer();
}
// LOG(WARNING);
}
template class LlamaDecoderSelfAttentionLayer<float>;
template class LlamaDecoderSelfAttentionLayer<half>;
} // namespace turbomind
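
// The attention layer above derives its split-k factor from the ratio avg_batch_size = sum_seq_len / max_seq_len:
// a few long sequences get more k-splits, large batches get fewer.  A standalone sketch of that heuristic;
// kMaxSplitK mirrors the layer's constant, everything else (the sample inputs, main()) is illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    constexpr int kMaxSplitK = 16;
    const int     cases[][2] = {{4096, 4096}, {8192, 1024}, {64, 32}};  // {sum_seq_len, max_seq_len}
    for (const auto& c : cases) {
        const float avg_batch_size = c[1] ? (float)c[0] / c[1] : 1.f;
        const int   max_split_k    = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
        std::printf("sum=%d max=%d -> avg_bsz=%.1f, max_split_k=%d\n", c[0], c[1], avg_batch_size, max_split_k);
    }
    return 0;
}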
/*
* Copyright (c) OpenMMLab. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/DecoderSelfAttentionLayer.h
#pragma once
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/nccl_utils.h"
namespace turbomind {
template<typename T>
class LlamaDecoderSelfAttentionLayer {
public:
void freeBuffer();
void allocateBuffer(size_t batch_size);
LlamaDecoderSelfAttentionLayer(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
const LlamaAttentionParams& attn_params,
NcclParam tensor_para,
cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int cache_block_seq_len,
int quant_policy):
head_num_(head_num),
kv_head_num_(kv_head_num),
size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_),
local_kv_head_num_(kv_head_num_ / tensor_para.world_size_),
local_hidden_units_(hidden_units_ / tensor_para.world_size_),
params_(attn_params),
tensor_para_(tensor_para),
stream_(stream),
linear_(cublas_wrapper, stream),
allocator_(allocator),
kv_cache_block_len_(cache_block_seq_len),
is_free_buffer_after_forward_(is_free_buffer_after_forward),
quant_policy_(quant_policy)
{
arch_ = getSMVersion();
}
~LlamaDecoderSelfAttentionLayer()
{
freeBuffer();
}
void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight<T>* weights);
private:
const size_t head_num_;
const size_t kv_head_num_;
const size_t size_per_head_;
const size_t hidden_units_;
const size_t local_head_num_;
const size_t local_kv_head_num_;
const size_t local_hidden_units_;
const size_t kv_cache_block_len_;
const bool is_free_buffer_after_forward_;
const int quant_policy_;
const LlamaAttentionParams& params_;
NcclParam tensor_para_;
cudaStream_t stream_;
IAllocator* allocator_;
LlamaLinear<T> linear_;
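// Scratch buffers (re)allocated per batch in allocateBuffer(): qkv_buf_
// presumably holds the fused QKV projection for the current step, and
// context_buf_ the attention output fed to the final output projection.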
T* qkv_buf_ = nullptr;
T* context_buf_ = nullptr;
static constexpr int kMaxSplitK = 16; // must be <= WARP_SIZE
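// workspace_ is assumed to hold the float partial results reduced by the
// split-k attention path, whose parallelism is bounded by kMaxSplitK.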
float* workspace_ = nullptr;
bool is_allocate_buffer_{};
int arch_{};
};
} // namespace turbomind
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "src/turbomind/models/llama/SequenceManager.h" #include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/unified_decoder.h"
#include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/logger.h"
...@@ -47,19 +48,14 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -47,19 +48,14 @@ LlamaV2<T>::LlamaV2(size_t head_num,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
size_t vocab_size, size_t vocab_size,
const LlamaAttentionParams& attn_params,
float norm_eps, float norm_eps,
int max_batch_size, const LlamaAttentionParams& attn_params,
int max_context_token_num,
int session_len,
int step_length,
int start_id, int start_id,
int end_id, int end_id,
float cache_max_block_count,
int cache_block_seq_len, int cache_block_seq_len,
int cache_chunk_size,
int quant_policy, int quant_policy,
bool use_context_fmha, bool use_context_fmha,
const EngineParams& engine_params,
std::shared_ptr<SharedState> shared_state, std::shared_ptr<SharedState> shared_state,
LlamaWeight<T>* weights, LlamaWeight<T>* weights,
NcclParam tensor_para, NcclParam tensor_para,
...@@ -89,7 +85,6 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -89,7 +85,6 @@ LlamaV2<T>::LlamaV2(size_t head_num,
is_free_buffer_after_forward_(is_free_buffer_after_forward), is_free_buffer_after_forward_(is_free_buffer_after_forward),
cuda_device_prop_(cuda_device_prop), cuda_device_prop_(cuda_device_prop),
debug_(isDebug()), debug_(isDebug()),
step_length_(step_length),
shared_state_(shared_state) shared_state_(shared_state)
{ {
...@@ -99,38 +94,7 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -99,38 +94,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
vocab_size_padded_ = vocab_size_padded_ =
(vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_; (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
size_t elem_bits = 0; batch_ = std::make_unique<LlamaBatch<T>>(engine_params, cache_block_seq_len, quant_policy, this);
if (quant_policy & QuantPolicy::kCacheKVInt8) {
elem_bits = sizeof(int8_t) * 8;
}
else {
elem_bits = sizeof(T) * 8;
}
const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
auto sequence_manager = std::make_unique<SequenceManager>(num_layer,
local_kv_head_num,
size_per_head_,
cache_block_seq_len,
cache_max_block_count,
cache_chunk_size,
elem_bits,
tensor_para_.rank_,
allocator);
const size_t max_session_len = sequence_manager->max_block_count() * cache_block_seq_len;
if (max_session_len < session_len) {
if (tensor_para.rank_ == 0) {
TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
session_len,
max_session_len);
}
session_len = max_session_len;
}
batch_ = std::make_unique<LlamaBatch<T>>(
max_batch_size, max_context_token_num, session_len, std::move(sequence_manager), this);
initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy); initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy);
...@@ -141,9 +105,8 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -141,9 +105,8 @@ LlamaV2<T>::LlamaV2(size_t head_num,
template<typename T> template<typename T>
LlamaV2<T>::~LlamaV2() LlamaV2<T>::~LlamaV2()
{ {
delete decoder_; unified_decoder_.reset();
delete dynamic_decode_layer_; delete dynamic_decode_layer_;
delete context_decoder_;
} }
template<typename T> template<typename T>
...@@ -155,36 +118,21 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params, ...@@ -155,36 +118,21 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
context_decoder_ = new LlamaContextDecoder<T>(head_num_, unified_decoder_.reset(new UnifiedDecoder<T>(head_num_,
kv_head_num, kv_head_num,
size_per_head_, size_per_head_,
inter_size_, inter_size_,
num_layer_, num_layer_,
attn_params, attn_params,
rmsnorm_eps_, rmsnorm_eps_,
tensor_para_, tensor_para_,
stream_, stream_,
cublas_wrapper_, cublas_wrapper_,
allocator_, allocator_,
is_free_buffer_after_forward_, is_free_buffer_after_forward_,
use_context_fmha, use_context_fmha,
cache_block_seq_len, cache_block_seq_len,
quant_policy); quant_policy));
decoder_ = new LlamaDecoder<T>(head_num_,
kv_head_num,
size_per_head_,
inter_size_,
num_layer_,
attn_params,
rmsnorm_eps_,
tensor_para_,
stream_,
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
cache_block_seq_len,
quant_policy);
dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_, dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
vocab_size_padded_, vocab_size_padded_,
...@@ -218,31 +166,32 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba ...@@ -218,31 +166,32 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
} }
template<typename T> template<typename T>
void LlamaV2<T>::contextDecode(T* decoder_output, void LlamaV2<T>::forwardUnified(T* out,
uintptr_t* k_cache_ptr, T* decoder_output,
uintptr_t* v_cache_ptr, T* decoder_input,
void** tmp_k_ptrs, void** k_block_ptrs,
void** tmp_v_ptrs, void** v_block_ptrs,
T* context_decoder_input_buf, const int* input_ids,
T* context_decoder_output_buf, const int* cu_block_cnts,
const int* input_ids, const float* rope_theta,
const int* input_length, const bool* dc_finished,
const int* context_length, const int* pf_input_length,
const int* cu_block_counts, const int* pf_context_length,
const float* rope_theta, T** pf_tmp_k_ptrs,
size_t token_num, T** pf_tmp_v_ptrs,
size_t max_input_len, size_t token_num,
size_t max_context_len, int dc_batch_size,
size_t session_len, int dc_step,
size_t batch_size) int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len)
{ {
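// In the unified pass, "dc_" arguments presumably describe the decode (generation)
// half of the batch and "pf_" arguments the prefill (context) half; the two halves
// are packed back to back, so bsz = dc_batch_size + pf_batch_size below.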
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (tensor_para_.rank_ == 0) { invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
TM_LOG_INFO("context decoding start");
}
invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf,
nullptr, // processed somewhere else nullptr, // processed somewhere else
weights_->pre_decoder_embedding_table, weights_->pre_decoder_embedding_table,
static_cast<T*>(nullptr), static_cast<T*>(nullptr),
...@@ -256,81 +205,32 @@ void LlamaV2<T>::contextDecode(T* decoder_output, ...@@ -256,81 +205,32 @@ void LlamaV2<T>::contextDecode(T* decoder_output,
stream_); stream_);
sync_check_cuda_error(); sync_check_cuda_error();
const auto dtype = getTensorType<T>(); const auto dtype = getTensorType<T>();
const auto bsz = batch_size; const size_t bsz = dc_batch_size + pf_batch_size;
const int max_q_len = max_input_len; TensorMap inputs{{"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, decoder_input}},
const int max_kv_len = max_context_len; {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
const int max_seq_len = session_len; {"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, pf_input_length}},
{"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, pf_context_length}},
std::unordered_map<std::string, Tensor> decoder_input_tensors{ {"dc_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &dc_batch_size}},
{"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_input_buf}}, {"dc_sum_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &dc_sum_seq_len}},
{"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}}, {"dc_max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &dc_max_seq_len}},
{"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, input_length}}, {"finished", {MEMORY_GPU, TYPE_BOOL, {bsz}, dc_finished}},
{"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, context_length}}, {"pf_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &pf_batch_size}},
{"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}}, {"pf_max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_max_input_len}},
{"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}}, {"pf_max_k_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_max_context_len}},
{"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}}, {"session_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_session_len}},
{"rope_theta", {MEMORY_GPU, TYPE_FP32, {hidden_units_}, rope_theta}}, {"rope_theta", {MEMORY_GPU, TYPE_FP32, {hidden_units_}, rope_theta}},
{"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}}}; {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {bsz}, cu_block_cnts}}};
std::unordered_map<std::string, Tensor> decoder_output_tensors{ TensorMap outputs{{"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, decoder_output}},
{"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_output_buf}}, {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_block_ptrs}},
{"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}}, {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_block_ptrs}},
{"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}}, {"tmp_k", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_k_ptrs}},
{"tmp_k", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_k_ptrs}}, {"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_v_ptrs}},
{"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, tmp_v_ptrs}}, {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, out}}};
{"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, decoder_output}}};
unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights);
context_decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
if (tensor_para_.rank_ == 0) {
TM_LOG_INFO("context decoding end");
}
}
template<typename T>
void LlamaV2<T>::decoderForward(T* decoder_output,
uintptr_t* k_cache_ptr,
uintptr_t* v_cache_ptr,
T* decoder_input,
const int* sequence_length,
const bool* finished,
const int* cu_block_counts,
const float* rope_theta,
int step,
int ite,
int sum_seq_len,
int max_seq_len,
size_t batch_size)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
const auto dtype = getTensorType<T>();
// max_input_length is not used w/o linear_bias_slopes
// sequence_lengths_ will be incremented in dynamic decode
std::unordered_map<std::string, Tensor> decoder_input_tensors{
{"decoder_input", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_input}},
{"sequence_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
{"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}},
{"sum_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &sum_seq_len}},
{"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}},
{"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
{"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
{"rope_theta", {MEMORY_GPU, TYPE_FP32, {batch_size}, rope_theta}},
{"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
{"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}},
};
// LOG(ERROR) << key_cache_ << " " << value_cache_;
std::unordered_map<std::string, Tensor> decoder_output_tensors{
{"decoder_output", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_output}},
{"key_cache", {MEMORY_GPU, TYPE_UINT64, {batch_size}, k_cache_ptr}},
{"value_cache", {MEMORY_GPU, TYPE_UINT64, {batch_size}, v_cache_ptr}},
};
decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights);
} }
template<typename T> template<typename T>
......
...@@ -24,12 +24,11 @@ ...@@ -24,12 +24,11 @@
#include "src/turbomind/layers/DynamicDecodeLayer.h" #include "src/turbomind/layers/DynamicDecodeLayer.h"
#include "src/turbomind/models/llama/Barrier.h" #include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/models/llama/LlamaBatch.h" #include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/models/llama/LlamaDecoder.h"
#include "src/turbomind/models/llama/LlamaWeight.h" #include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/Request.h" #include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h" #include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/unified_decoder.h"
#include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/instance_comm.h" #include "src/turbomind/utils/instance_comm.h"
...@@ -59,19 +58,14 @@ public: ...@@ -59,19 +58,14 @@ public:
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
size_t vocab_size, size_t vocab_size,
const LlamaAttentionParams& attn_params,
float norm_eps, float norm_eps,
int max_batch_size, const LlamaAttentionParams& attn_params,
int max_context_token_num,
int session_len,
int step_length,
int start_id, int start_id,
int end_id, int end_id,
float cache_max_block_count,
int cache_block_seq_len, int cache_block_seq_len,
int cache_chunk_size,
int quant_policy, int quant_policy,
bool use_context_fmha, bool use_context_fmha,
const EngineParams& engine_params,
std::shared_ptr<SharedState> shared_state, std::shared_ptr<SharedState> shared_state,
LlamaWeight<T>* weights, LlamaWeight<T>* weights,
NcclParam tensor_para, NcclParam tensor_para,
...@@ -113,37 +107,28 @@ private: ...@@ -113,37 +107,28 @@ private:
void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
void contextDecode(T* decoder_output, void forwardUnified(T* out,
uintptr_t* k_block_ptrs, T* decoder_output,
uintptr_t* v_block_ptrs,
void** k_tmp_ptrs,
void** v_tmp_ptrs,
T* context_decoder_input_buf,
T* context_decoder_output_buf,
const int* input_ids,
const int* input_length,
const int* context_length,
const int* cu_block_counts,
const float* rope_theta,
size_t token_num,
size_t max_input_len,
size_t max_context_len,
size_t session_len,
size_t batch_size);
void decoderForward(T* decoder_output,
uintptr_t* k_cache_ptr,
uintptr_t* v_cache_ptr,
T* decoder_input, T* decoder_input,
const int* sequence_length, void** k_block_ptrs,
const bool* finished, void** v_block_ptrs,
const int* cu_block_counts, const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta, const float* rope_theta,
int step, const bool* dc_finished,
int ite, const int* pf_input_length,
int sum_seq_len, const int* pf_context_length,
int max_seq_len, T** pf_tmp_k_ptrs,
size_t batch_size); T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len);
void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size); void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);
...@@ -195,12 +180,11 @@ private: ...@@ -195,12 +180,11 @@ private:
const bool debug_{false}; const bool debug_{false};
LlamaWeight<T>* weights_{}; LlamaWeight<T>* weights_{};
LlamaDecoder<T>* decoder_{};
LlamaContextDecoder<T>* context_decoder_{}; std::unique_ptr<UnifiedDecoder<T>> unified_decoder_;
DynamicDecodeLayer<float>* dynamic_decode_layer_{}; DynamicDecodeLayer<float>* dynamic_decode_layer_{};
const int step_length_;
std::shared_ptr<SharedState> shared_state_; std::shared_ptr<SharedState> shared_state_;
ffi_api_lock_ctrl_t ffi_lock_; ffi_api_lock_ctrl_t ffi_lock_;
std::unique_ptr<LlamaBatch<T>> batch_; std::unique_ptr<LlamaBatch<T>> batch_;
......