Commit 44be91d3 authored by zhouxiang

Sync features from the new upstream version; fix the qwen continuous-output issue, among other changes.

parent aefd9f11
...@@ -18,7 +18,7 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNOMINMAX -O2 /std:c++17 /arch:AVX /source-charset:utf-8")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2 -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2 -march=native")
endif()
...@@ -42,8 +42,9 @@ if (USE_CUDA)
#message(${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
set(FASTLLM_CUDA_SOURCES src/devices/cuda/cudadevice.cpp src/devices/cuda/cudadevicebatch.cpp src/devices/cuda/fastllm-cuda.cu)
set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} cublas)
set(CMAKE_CUDA_ARCHITECTURES "native")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g --gpu-max-threads-per-block=1024")
#set(CMAKE_CUDA_ARCHITECTURES "70")
endif()
if (PY_API)
...@@ -84,6 +85,8 @@ add_custom_command(
add_executable(benchmark example/benchmark/benchmark.cpp)
target_link_libraries(benchmark fastllm)
add_executable(apiserver example/apiserver/apiserver.cpp example/apiserver/json11.cpp)
target_link_libraries(apiserver fastllm)
add_library(fastllm_tools SHARED ${FASTLLM_CXX_SOURCES} ${FASTLLM_CUDA_SOURCES} tools/src/pytools.cpp)
target_link_libraries(fastllm_tools PUBLIC ${FASTLLM_LINKED_LIBS})
......
...@@ -128,6 +128,8 @@ struct APIConfig {
int threads = 4; // number of threads to use
bool lowMemMode = false; // whether to use low-memory mode
int port = 8080; // port number
int tokens = -1; // token capacity limit
int batch = 256; // batch count limit
};
void ToNext(char * &cur, const std::string &target, std::string &v) {
...@@ -178,13 +180,40 @@ struct HttpRequest {
}
}
bool IsValid (char *buffer, int size) {
char *old = buffer;
headers.clear();
ToNext(buffer, " ", method);
ToNext(buffer, " ", route);
ToNext(buffer, "\r\n", type);
while (true) {
if (buffer[0] == 0 || ((long long)(buffer - old)) > 1024 * 1024) {
break;
}
if (buffer[0] == '\r' && buffer[1] == '\n') {
if (headers.find("Content-Length") != headers.end()) {
if (size - ((long long)(buffer - old)) - 2 >= atoi(headers["Content-Length"].c_str())) {
return true;
} else {
return false;
}
}
} else {
std::string key;
ToNext(buffer, ":", key);
ToNext(buffer, "\r\n", headers[key]);
}
}
return false;
}
void Print() {
for (auto &it : headers) {
printf("%s: %s\n", it.first.c_str(), it.second.c_str());
}
printf("body: %s\n", body.c_str());
}
};
} httpChecker;
struct WorkNode {
int client;
...@@ -201,7 +230,7 @@ struct WorkNode {
struct WorkQueue {
std::unique_ptr<fastllm::basellm> model;
int maxActivateQueryNumber = 128;
int maxActivateQueryNumber = 256;
int activateQueryNumber = 0;
int totalQueryNumber = 0;
std::mutex locker;
...@@ -234,10 +263,12 @@ struct WorkQueue {
WorkNode *now = ts->q.front();
ts->q.pop();
ts->activateQueryNumber++;
//ts->totalQueryNumber++;
//printf("totalQueryNumber = %d\n", ts->totalQueryNumber);
ts->totalQueryNumber++;
printf("totalQueryNumber = %d\n", ts->totalQueryNumber);
//printf("activate = %d, q.size() = %d\n", ts->activateQueryNumber, (int) ts->q.size());
new std::thread([](WorkQueue *ts, WorkNode *now) {
std::thread *t = new std::thread([](WorkQueue *ts, WorkNode *now) {
ts->Deal(now);
printf("Response client %d finish\n", now->client);
ts->locker.lock();
...@@ -310,11 +341,13 @@ void Usage() {
std::cout << "<-w|--web> <args>: path to the web files" << std::endl;
std::cout << "<-t|--threads> <args>: number of threads to use" << std::endl;
std::cout << "<-l|--low>: use low-memory mode" << std::endl;
std::cout << "<--batch>: maximum batch count" << std::endl;
std::cout << "<--tokens>: maximum token capacity" << std::endl;
std::cout << "<--port> <args>: web port number" << std::endl;
}
void ParseArgs(int argc, char **argv, APIConfig &config) {
std::vector<std::string> sargv;
for (int i = 0; i < argc; i++) {
sargv.push_back(std::string(argv[i]));
}
...@@ -332,6 +365,10 @@ void ParseArgs(int argc, char **argv, APIConfig &config) {
config.webPath = sargv[++i];
} else if (sargv[i] == "--port") {
config.port = atoi(sargv[++i].c_str());
} else if (sargv[i] == "--tokens") {
config.tokens = atoi(sargv[++i].c_str());
} else if (sargv[i] == "--batch") {
config.batch = atoi(sargv[++i].c_str());
} else {
Usage();
exit(-1);
...@@ -350,6 +387,8 @@ int main(int argc, char** argv) {
fastllm::SetThreads(config.threads);
fastllm::SetLowMemMode(config.lowMemMode);
workQueue.model = fastllm::CreateLLMModelFromFile(config.path);
workQueue.model->tokensLimit = config.tokens;
workQueue.maxActivateQueryNumber = std::max(1, std::min(256, config.batch));
workQueue.Start();
int local_fd = socket(AF_INET, SOCK_STREAM, 0);
...@@ -375,7 +414,6 @@ int main(int argc, char** argv) {
listen(local_fd, 2000);
printf("start...\n");
int queuePos = 0;
while (true) { // loop to accept client requests
//5. create a sockaddr_in struct to store the client's address
struct sockaddr_in client_addr;
...@@ -386,8 +424,19 @@ int main(int argc, char** argv) {
exit(-1);
}
int size = read(client, buff, sizeof(buff));
int size = 0;
while (true) {
int cur = read(client, buff + size, sizeof(buff) - size);
size += cur;
if (httpChecker.IsValid(buff, size)) {
break;
}
}
buff[size] = 0;
while (workQueue.q.size() > workQueue.maxActivateQueryNumber) {
sleep(0);
}
workQueue.Push(buff, client);
}
......
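The rewritten accept loop above keeps calling read() until HttpRequest::IsValid() sees a complete header block and, when a Content-Length header is present, a full body. Below is a minimal, self-contained sketch of that completeness test; the helper name and parsing details are illustrative, not the commit's code.

```cpp
// Sketch only: returns true once `req` holds the full headers plus Content-Length bytes of body.
#include <cstdlib>
#include <string>

bool LooksComplete(const std::string &req) {
    size_t headerEnd = req.find("\r\n\r\n");
    if (headerEnd == std::string::npos) {
        return false;                                   // headers not finished yet
    }
    size_t clPos = req.find("Content-Length:");
    if (clPos == std::string::npos || clPos > headerEnd) {
        return true;                                    // no body expected
    }
    long contentLength = std::strtol(req.c_str() + clPos + 15, nullptr, 10);
    return (long) (req.size() - headerEnd - 4) >= contentLength;
}
```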
...@@ -149,6 +149,11 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
class CpuCopyKVCacheOp : BaseOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
class CpuSplitBatchOp : BaseBatchOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
...@@ -180,6 +185,11 @@ namespace fastllm {
class CpuCatDirectBatchOp : BaseBatchOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
class CpuAttentionBatchOp : BaseBatchOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
}
#endif //FASTLLM_CPUDEVICE_H
...@@ -24,6 +24,11 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
class CudaCopyKVCacheOp : BaseOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
class CudaLayerNormOp : BaseOperator {
bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
...@@ -154,6 +159,11 @@ namespace fastllm {
class CudaCatDirectBatchOp : BaseBatchOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
class CudaAttentionBatchOp : BaseBatchOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
}
#endif //FASTLLM_CUDADEVICE_H
...@@ -9,6 +9,8 @@ void FastllmCudaMallocBigBuffer(size_t size);
void FastllmCudaClearBigBuffer();
void *FastllmCudaMalloc(size_t size);
void FastllmCudaFree(void *ret);
void * FastllmCudaDirectMalloc(size_t size);
void FastllmCudaDirectFree(void *ret);
void FastllmCudaCopyFromHostToDevice(void *dst, void *src, size_t size);
void FastllmCudaCopyFromDeviceToHost(void *dst, void *src, size_t size);
...@@ -55,6 +57,8 @@ bool FastllmCudaLlamaRotatePosition2D(fastllm::Data &data, const fastllm::Data &
const fastllm::Data &sinData, const fastllm::Data &cosData, int rotaryDim);
bool FastllmCudaApplyLognAttn (fastllm::Data &input, fastllm::Data &lognAttn, fastllm::Data &positionIds);
bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v,
fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch);
bool FastllmCudaSplitBatch(fastllm::Data &input, fastllm::Data **outputs, int axis);
bool FastllmCudaCatBatch(fastllm::Data **inputs, fastllm::Data &output, int axis);
bool FastllmCudaMulBatch(fastllm::Data **inputs, float v, int batch, fastllm::Data **outputs);
......
...@@ -247,6 +247,8 @@ namespace fastllm {
long long filePos;
std::shared_ptr<FileMmap> m_file;
bool directMemory = false; // allocate/free memory directly, bypassing the cache
Data () {};
Data (DataType type);
...@@ -364,6 +366,8 @@ namespace fastllm {
void TryMergePairs(std::vector<Symbol> &symbols, int l, int r, std::priority_queue <SymbolPairs> &q); // insert a candidate symbol pair
int GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip);
void Insert(const std::string &s, int tokenId, float score = 1.0f); // insert a token
Data Encode(const std::string &s); // encode
...@@ -418,9 +422,15 @@ namespace fastllm {
void ToDataType(const Data &input, DataType dataType);
void CopyKVCache(Data &oldCache, Data &newCache, int oldBsStart, int newBsStart, int bs, int offset);
void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
int group, float scale, int attentionType);
void AttentionBatch(std::vector <Data*> &q, std::vector <Data*> &k, std::vector <Data*> &v,
std::vector <Data*> &mask, std::vector <Data*> &output,
int group, float scale, int attentionType);
void Embedding(const Data &input, Data &weight, Data &output);
void RMSNorm(const Data &input, const Data &weight, float eps, Data &output);
......
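A hedged usage sketch of the newly declared CopyKVCache() entry point: when the set of running requests changes, the cached keys/values of `bs` requests are copied from the old cache tensor into a freshly expanded one at a position offset. The surrounding function and variable names below are hypothetical, not taken from the commit.

```cpp
#include "fastllm.h"

// Sketch: move batchSize cache rows from oldCache into newCache, shifted by promptOffset
// positions inside each row of the new cache (argument meanings follow the declaration above).
void GrowKVCache(fastllm::Data &oldCache, fastllm::Data &newCache,
                 int batchSize, int promptOffset) {
    fastllm::CopyKVCache(oldCache, newCache, /*oldBsStart=*/0, /*newBsStart=*/0,
                         /*bs=*/batchSize, /*offset=*/promptOffset);
}
```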
...@@ -152,5 +152,7 @@ namespace fastllm {
std::map <std::string, int> deviceMap;
std::string adapterName;
int tokensLimit = -1;
};
}
...@@ -21,6 +21,12 @@
#ifdef __AVX__
#include "immintrin.h"
#ifdef __GNUC__
#if __GNUC__ < 8
#define _mm256_set_m128i(/* __m128i */ hi, /* __m128i */ lo) \
_mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 0x1)
#endif
#endif
#endif
namespace fastllm {
......
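The guarded macro above backfills _mm256_set_m128i for GCC older than 8, where the intrinsic is missing. A tiny sketch of what it provides, namely packing two 128-bit integer vectors into one 256-bit vector (compile with AVX enabled; purely illustrative):

```cpp
#include <immintrin.h>

// lo becomes the lower 128-bit lane, hi the upper one (matching the shim's insertf128 expansion).
__m256i PackLanes(__m128i lo, __m128i hi) {
    return _mm256_set_m128i(hi, lo);
}
```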
...@@ -23,6 +23,7 @@ namespace fastllm {
this->ops["ToFloat16"] = (BaseOperator*)(new CpuToFloat16());
this->ops["ToFloat32"] = (BaseOperator*)(new CpuToFloat32());
this->ops["Attention"] = (BaseOperator*)(new CpuAttention());
this->ops["CopyKVCache"] = (BaseOperator*)(new CpuCopyKVCacheOp());
this->ops["Embedding"] = (BaseOperator*)(new CpuEmbedding());
this->ops["LayerNorm"] = (BaseOperator*)(new CpuLayerNormOp());
this->ops["RMSNorm"] = (BaseOperator*)(new CpuRMSNormOp());
...@@ -57,6 +58,7 @@ namespace fastllm {
this->ops["MatMulTransBBatch"] = (BaseOperator*)(new CpuMatMulTransBBatchOp());
this->ops["SoftMaxBatch"] = (BaseOperator*)(new CpuSoftmaxBatchOp());
this->ops["CatDirectBatch"] = (BaseOperator*)(new CpuCatDirectBatchOp());
this->ops["AttentionBatch"] = (BaseOperator*)(new CpuAttentionBatchOp());
}
bool CpuDevice::Malloc(void **ret, size_t size) {
...@@ -77,7 +79,7 @@ namespace fastllm {
return true;
}
#ifdef __AVX__
#ifdef __AVX2__
int DotU8U8(uint8_t *a, uint8_t *b, int n) {
__m256i acc = _mm256_setzero_si256();
...@@ -105,32 +107,31 @@ namespace fastllm {
return ans + I32sum(acc);
};
//#else
// int DotU8U8(uint8_t *a, uint8_t *b, int n) {
// __m256i acc = _mm256_setzero_si256();
// int i = 0;
// int ans = 0;
// for (; i + 31 < n; i += 32) {
// __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
// __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));
// __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
// __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));
// __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
// __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));
// acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
// //acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
// }
// for (; i < n; i++) {
// ans += a[i] * b[i];
// }
// return ans + I32sum(acc);
// };
#endif
int DotU4U8(uint8_t *a, uint8_t *b, int n) {
__m256i acc = _mm256_setzero_si256();
...@@ -280,7 +281,7 @@ namespace fastllm {
float *qd = (float*)q.cpuData;
float *kd = (float*)k.cpuData;
float *vd = (float*)v.cpuData;
float *maskd = mask.dims.size() > 0 ? (float*)mask.cpuData : nullptr;
float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr;
float *od = (float*)output.cpuData;
std::fill(od, od + output.Count(0), 0.0f);
auto pool = GetPool();
...@@ -296,6 +297,30 @@ namespace fastllm {
}
}
void CpuCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
return;
}
void CpuCopyKVCacheOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &oldCache = *(datas.find("oldCache")->second);
Data &newCache = *(datas.find("newCache")->second);
int oldBsStart = intParams.find("oldBsStart") != intParams.end() ? intParams.find("oldBsStart")->second : -1;
int newBsStart = intParams.find("newBsStart") != intParams.end() ? intParams.find("newBsStart")->second : -1;
int bs = intParams.find("bs") != intParams.end() ? intParams.find("bs")->second : -1;
int offset = intParams.find("offset") != intParams.end() ? intParams.find("offset")->second : -1;
int unitSize = oldCache.unitSize;
for (int o = 0; o < bs; o++) {
uint8_t *cur = newCache.cpuData + (newBsStart + o) * newCache.strides[0] * unitSize;
cur += offset * newCache.strides[1] * unitSize;
uint8_t *old = oldCache.cpuData + (oldBsStart + o) * oldCache.strides[0] * unitSize;
memcpy(cur, old, oldCache.dims[1] * oldCache.dims[2] * unitSize);
}
}
void CpuEmbedding::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
...@@ -894,7 +919,7 @@ namespace fastllm {
c[block * kstride + i] = value;
}
}
#elif defined(__AVX__)
#elif defined(__AVX2__)
int block = 0;
for (; block < n; block++) {
uint8_t *weightWalk = b;
...@@ -968,7 +993,7 @@ namespace fastllm {
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX__)
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif
...@@ -1039,7 +1064,7 @@ namespace fastllm {
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX__)
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif
......
...@@ -202,4 +202,45 @@ namespace fastllm {
}
delete op;
}
void CpuAttentionBatchOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data **qs = (Data**)(datas.find("q")->second);
Data **ks = (Data**)(datas.find("k")->second);
Data **vs = (Data**)(datas.find("v")->second);
Data **outputs = (Data**)(datas.find("output")->second);
int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
int batch = intParams.find("q___batch")->second;
Data &q = *qs[0], &k = *ks[0], &v = *vs[0];
AssertInFastLLM(q.dims.size() == 3 && k.dims.size() == 3 && v.dims.size() == 3, "Attention: dims of q, k, v should be 3.\n");
AssertInFastLLM(q.dims[2] == k.dims[2], "Attention: q.dims[2] should be equal to k.dims[2].\n");
AssertInFastLLM(k.dims[1] == v.dims[1], "Attention: k.dims[1] should be equal to v.dims[1].\n");
AssertInFastLLM(k.dims[0] == v.dims[0], "Attention: k.dims[0] should be equal to v.dims[0].\n");
AssertInFastLLM(q.dims[0] == k.dims[0] * group, "Attention: q.dims[0] should be equal to k.dims[0] * group.\n");
AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType,
"Attention: q, k, v's datatype should be same.\n");
AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
for (int i = 0; i < batch; i++) {
outputs[i]->dataType = qs[i]->dataType;
outputs[i]->Resize({qs[i]->dims[0], qs[i]->dims[1], vs[i]->dims[2]});
}
}
void CpuAttentionBatchOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
fastllm::BaseOperator *op = (fastllm::BaseOperator*)(new CpuAttention());
int batch = intParams.find("q___batch")->second;
DataDict tempDatas = datas;
for (int i = 0; i < batch; i++) {
tempDatas["q"] = ((Data**)datas.find("q")->second)[i];
tempDatas["k"] = ((Data**)datas.find("k")->second)[i];
tempDatas["v"] = ((Data**)datas.find("v")->second)[i];
tempDatas["mask"] = ((Data**)datas.find("mask")->second)[i];
tempDatas["output"] = ((Data**)datas.find("output")->second)[i];
op->Run("Attention", tempDatas, floatParams, intParams);
}
delete op;
}
}
\ No newline at end of file
...@@ -13,6 +13,7 @@ namespace fastllm {
CudaDevice::CudaDevice() {
this->deviceType = "cuda";
this->ops["Attention"] = (BaseOperator*)(new CudaAttention());
this->ops["CopyKVCache"] = (BaseOperator*)(new CudaCopyKVCacheOp());
this->ops["LayerNorm"] = (BaseOperator*)(new CudaLayerNormOp());
this->ops["RMSNorm"] = (BaseOperator*)(new CudaRMSNormOp());
this->ops["Linear"] = (BaseOperator*)(new CudaLinearOp());
...@@ -43,6 +44,7 @@ namespace fastllm {
this->ops["MatMulTransBBatch"] = (BaseOperator*)(new CudaMatMulTransBBatchOp());
this->ops["SoftMaxBatch"] = (BaseOperator*)(new CudaSoftmaxBatchOp());
this->ops["CatDirectBatch"] = (BaseOperator*)(new CudaCatDirectBatchOp());
this->ops["AttentionBatch"] = (BaseOperator*)(new CudaAttentionBatchOp());
}
bool CudaDevice::Malloc(void **ret, size_t size) {
...@@ -90,10 +92,11 @@ namespace fastllm {
void CudaAttention::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data emptyData;
Data &q = *(datas.find("q")->second);
Data &k = *(datas.find("k")->second);
Data &v = *(datas.find("v")->second);
Data &mask = *(datas.find("mask")->second);
Data &mask = datas.find("mask")->second ? *(datas.find("mask")->second) : emptyData;
Data &output = *(datas.find("output")->second);
int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
...@@ -101,6 +104,31 @@ namespace fastllm {
FastllmCudaAttention(q, k, v, mask, output, group, scale);
}
void CudaCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
return;
}
void CudaCopyKVCacheOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &oldCache = *(datas.find("oldCache")->second);
Data &newCache = *(datas.find("newCache")->second);
int oldBsStart = intParams.find("oldBsStart") != intParams.end() ? intParams.find("oldBsStart")->second : -1;
int newBsStart = intParams.find("newBsStart") != intParams.end() ? intParams.find("newBsStart")->second : -1;
int bs = intParams.find("bs") != intParams.end() ? intParams.find("bs")->second : -1;
int offset = intParams.find("offset") != intParams.end() ? intParams.find("offset")->second : -1;
int unitSize = oldCache.unitSize;
FastllmCudaMemcpy2DDeviceToDevice((uint8_t *) newCache.cudaData + newBsStart * newCache.strides[0] * unitSize
+ offset * newCache.strides[1] * unitSize,
newCache.strides[0] * unitSize,
(uint8_t *) oldCache.cudaData + oldBsStart * oldCache.strides[0] * unitSize,
oldCache.strides[0] * unitSize,
oldCache.dims[1] * oldCache.dims[2] * unitSize, bs);
}
bool CudaRMSNormOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
return true;
......
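CudaCopyKVCacheOp::Run above performs one strided, row-wise device-to-device copy through fastllm's FastllmCudaMemcpy2DDeviceToDevice wrapper. For orientation, here is a sketch of what such a copy amounts to with the plain CUDA runtime API; this is illustrative, not the wrapper's actual body.

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Copies `rows` rows of `widthBytes` each; consecutive rows are dstPitch/srcPitch bytes apart.
void CopyCacheRows(uint8_t *dst, size_t dstPitch, const uint8_t *src, size_t srcPitch,
                   size_t widthBytes, int rows) {
    cudaMemcpy2D(dst, dstPitch, src, srcPitch, widthBytes, rows, cudaMemcpyDeviceToDevice);
}
```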
...@@ -311,4 +311,46 @@ namespace fastllm {
FastllmCudaMemcpy2DDeviceToDeviceBatch(dsts.data(), dpitchs.data(), srcs.data(),
spitchs.data(), widths.data(), heights.data(), dsts.size());
}
void CudaAttentionBatchOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data **qs = (Data**)(datas.find("q")->second);
Data **ks = (Data**)(datas.find("k")->second);
Data **vs = (Data**)(datas.find("v")->second);
Data **outputs = (Data**)(datas.find("output")->second);
int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
int batch = intParams.find("q___batch")->second;
Data &q = *qs[0], &k = *ks[0], &v = *vs[0];
AssertInFastLLM(q.dims.size() == 3 && k.dims.size() == 3 && v.dims.size() == 3, "Attention: dims of q, k, v should be 3.\n");
AssertInFastLLM(q.dims[2] == k.dims[2], "Attention: q.dims[2] should be equal to k.dims[2].\n");
AssertInFastLLM(k.dims[1] == v.dims[1], "Attention: k.dims[1] should be equal to v.dims[1].\n");
AssertInFastLLM(k.dims[0] == v.dims[0], "Attention: k.dims[0] should be equal to v.dims[0].\n");
AssertInFastLLM(q.dims[0] == k.dims[0] * group, "Attention: q.dims[0] should be equal to k.dims[0] * group.\n");
AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType,
"Attention: q, k, v's datatype should be same.\n");
AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
for (int i = 0; i < batch; i++) {
outputs[i]->dataType = qs[i]->dataType;
outputs[i]->Resize({qs[i]->dims[0], qs[i]->dims[1], vs[i]->dims[2]});
}
}
void CudaAttentionBatchOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
int batch = intParams.find("q___batch")->second;
Data **qs = (Data**)(datas.find("q")->second);
Data **ks = (Data**)(datas.find("k")->second);
Data **vs = (Data**)(datas.find("v")->second);
Data **masks = (Data**)(datas.find("mask")->second);
Data **outputs = (Data**)(datas.find("output")->second);
for (int i = 0; i < batch; i++) {
outputs[i]->Allocate();
}
FastllmCudaAttentionBatch(qs, ks, vs, masks, outputs,
intParams.find("group")->second,
floatParams.find("scale")->second,
intParams.find("q___batch")->second);
}
}
\ No newline at end of file
...@@ -800,6 +800,63 @@ __global__ void FastllmMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
int input0Stride = (int)((size_t)pointer[id * 8 + 6]);
int input1Stride = (int)((size_t)pointer[id * 8 + 7]);
int tid = threadIdx.x;
int pera = 4, perb = 4;
float cura[4][4], curb[4][4], curc[4][4];
int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) {
int taska = taskId / cntb, taskb = taskId % cntb;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
cura[i][j] = 0;
curb[i][j] = 0;
curc[i][j] = 0;
}
}
for (int l = 0; l < m; l += 4) {
for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) {
#pragma unroll
for (int x = 0; x < 4; x++) {
cura[a - taska * pera][x] = input0[a * input0Stride + l + x];
}
}
for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) {
#pragma unroll
for (int x = 0; x < 4; x++) {
curb[b - taskb * perb][x] = input1[b * input1Stride + l + x];
}
}
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
curc[i][j] += cura[i][k] * curb[j][k];
}
}
}
}
if ((taska + 1) * pera <= n && (taskb + 1) * perb <= k) {
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
}
}
} else {
for (int i = 0; i < pera && taska * pera + i < n; i++) {
for (int j = 0; j < perb && taskb * perb + j < k; j++) {
output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
}
}
}
}
/*
int tid = threadIdx.x;
for (int i = 0; i < n; i++) {
float *curInput0 = input0 + i * input0Stride;
...@@ -812,6 +869,7 @@ __global__ void FastllmMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
output[i * k + j] = sum * alpha;
}
}
*/
}
template <int THREAD_PER_BLOCK>
...@@ -827,6 +885,64 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) {
int input1Stride = (int)((size_t)pointer[id * 8 + 7]);
int tid = threadIdx.x;
int pera = 4, perb = 4;
float cura[4][4], curb[4][4], curc[4][4];
int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) {
int taska = taskId / cntb, taskb = taskId % cntb;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
cura[i][j] = 0;
curb[i][j] = 0;
curc[i][j] = 0;
}
}
for (int l = 0; l < m; l += 4) {
for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) {
#pragma unroll
for (int x = 0; x < 4; x++) {
cura[a - taska * pera][x] = l + x < m ? input0[a * input0Stride + l + x] : 0;
}
}
for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) {
#pragma unroll
for (int x = 0; x < 4; x++) {
curb[b - taskb * perb][x] = l + x < m ? input1[(l + x) * input1Stride + b] : 0;
}
}
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
curc[i][j] += cura[i][k] * curb[j][k];
}
}
}
}
if ((taska + 1) * pera <= n && (taskb + 1) * perb <= k) {
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
}
}
} else {
for (int i = 0; i < pera && taska * pera + i < n; i++) {
for (int j = 0; j < perb && taskb * perb + j < k; j++) {
output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
}
}
}
}
/*
//int tid = threadIdx.x;
for (int i = 0; i < n; i++) {
float *curInput0 = input0 + i * input0Stride;
for (int j = tid; j < k; j += THREAD_PER_BLOCK) {
...@@ -838,6 +954,7 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) {
output[i * k + j] = sum * alpha;
}
}
*/
}
template <int THREAD_PER_BLOCK>
...@@ -880,6 +997,71 @@ __global__ void FastllmAttentionKernel(float *qd, float *kd, float *vd, float *m
}
}
template <int THREAD_PER_BLOCK>
__global__ void FastllmAttentionBatchKernel(float** pointer, float scale, int group) {
const int params = 16;
int id = blockIdx.x;
float *qd = (float*) pointer[id * params + 0];
float *kd = (float*) pointer[id * params + 1];
float *vd = (float*) pointer[id * params + 2];
float *maskd = (float*) pointer[id * params + 3];
float *od = (float*) pointer[id * params + 4];
int q1 = (int)(unsigned long long)pointer[id * params + 5];
int q2 = (int)(unsigned long long)pointer[id * params + 6];
int k1 = (int)(unsigned long long)pointer[id * params + 7];
int v2 = (int)(unsigned long long)pointer[id * params + 8];
int qstride = (int)(unsigned long long)pointer[id * params + 9];
int kstride = (int)(unsigned long long)pointer[id * params + 10];
int vstride = (int)(unsigned long long)pointer[id * params + 11];
int ostride = (int)(unsigned long long)pointer[id * params + 12];
float *qk = (float*)pointer[id * params + 13];
float *temp = (float*)pointer[id * params + 14];
int q0 = (int)(unsigned long long)pointer[id * params + 15];
for (int o = 0; o < q0; o++) {
qd += o * qstride;
kd += (o / group) * kstride;
vd += (o / group) * vstride;
od += o * ostride;
qk += o * k1;
temp += o * k1;
for (int i = 0; i < q1; i++) {
for (int j = threadIdx.x; j < k1; j += THREAD_PER_BLOCK) {
if (maskd && maskd[i * k1 + j] > 0.99) {
qk[j] = -10000;
continue;
}
float sum = 0.0f;
float *tempQd = qd + i * q2, *tempKd = kd + j * q2;
for (int l = 0; l < q2; l++) {
sum += tempQd[l] * tempKd[l];
}
qk[j] = sum * scale;
}
__syncthreads();
FastllmSoftmaxKernelInner1Func<THREAD_PER_BLOCK>(qk, temp, k1);
__syncthreads();
for (int j = threadIdx.x; j < v2; j += THREAD_PER_BLOCK) {
float *curInput1 = vd + j;
float sum = 0.0;
for (int l = 0; l < k1; l++) {
sum += temp[l] * curInput1[l * v2];
}
od[i * v2 + j] = sum;
}
__syncthreads();
}
qd -= o * qstride;
kd -= (o / group) * kstride;
vd -= (o / group) * vstride;
od -= o * ostride;
qk -= o * k1;
temp -= o * k1;
}
}
void *FastllmCudaPrepareInput(const fastllm::Data &input) {
void *ret;
if (input.dataDevice == fastllm::DataDevice::CUDA) {
...@@ -1294,6 +1476,16 @@ std::map<int, std::vector <CudaMemoryBuffer>> cudaBuffersMap;
std::map<int, size_t> noBusyCnt;
std::map<int, std::vector <CudaMemoryBuffer>> bigBuffersMap;
void * FastllmCudaDirectMalloc(size_t size) {
void * ret;
cudaMalloc(&ret, size);
return ret;
}
void FastllmCudaDirectFree(void *ret) {
cudaFree(ret);
}
void * FastllmCudaMalloc(size_t size) {
int id = -1;
cudaGetDevice(&id);
...@@ -1302,7 +1494,7 @@ void * FastllmCudaMalloc(size_t size) {
int selId = -1;
for (int i = 0; i < bigBuffers.size(); i++) {
if (bigBuffers[i].size >= size && !bigBuffers[i].busy
&& bigBuffers[i].size - size < 32 * 1024 * 1024) {
&& bigBuffers[i].size - size < 1 * 1024 * 1024) {
if (selId == -1 || bigBuffers[selId].size > bigBuffers[i].size) {
selId = i;
}
...@@ -1841,6 +2033,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
}
FastllmCudaFree(qk);
DeviceSync();
return true;
}
...@@ -1896,6 +2089,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
}
FastllmCudaFree(qk);
FastllmCudaFree(temp);
DeviceSync();
return true;
}
return true;
...@@ -2044,6 +2238,157 @@ bool FastllmCudaApplyLognAttn (fastllm::Data &input, fastllm::Data &lognAttn, fa
return true;
}
bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v,
fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch) {
int k0 = k[0]->dims[0];
size_t memSum = 0;
for (int b = 0; b < batch; b++) {
memSum += q[b]->dims[0] * q[b]->dims[1] * k[b]->dims[1];
}
float *mem = (float*) FastllmCudaMalloc(memSum * sizeof(float));
float **qk = new float*[batch];
memSum = 0;
for (int b = 0; b < batch; b++) {
int s = q[b]->dims[0] * q[b]->dims[1] * k[b]->dims[1];
qk[b] = mem + memSum;
memSum += s;
}
if (true) {
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
for (int b = 0; b < batch; b++) {
for (int i = 0; i < k0; i++) {
cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) q[b]->cudaData + i * group * q[b]->dims[1] * q[b]->dims[2] * sizeof(float);
cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) k[b]->cudaData + i * k[b]->strides[0] * sizeof(float);
cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(float);
cpuPointers[(b * k0 + i) * 8 + 3] = (uint8_t *) (size_t) (group * q[b]->dims[1]);
cpuPointers[(b * k0 + i) * 8 + 4] = (uint8_t *) (size_t) q[b]->dims[2];
cpuPointers[(b * k0 + i) * 8 + 5] = (uint8_t *) (size_t) k[b]->dims[1];
cpuPointers[(b * k0 + i) * 8 + 6] = (uint8_t *) (size_t) q[b]->strides[1];
cpuPointers[(b * k0 + i) * 8 + 7] = (uint8_t *) (size_t) k[b]->strides[1];
}
}
cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * k0 * 8, cudaMemcpyHostToDevice);
FastllmMatMulTransBBatchKernel <128> <<<batch * k0, 128>>> (pointers, scale);
FastllmCudaFree(pointers);
delete[] cpuPointers;
}
if (true) {
int total = 0;
for (int b = 0; b < batch; b++) {
int outer = q[b]->dims[0] * q[b]->dims[1];
total += outer;
}
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * total * 3);
uint8_t ** cpuPointers = new uint8_t*[total * 3];
int cur = 0;
for (int b = 0; b < batch; b++) {
int outer = q[b]->dims[0] * q[b]->dims[1];
int channels = k[b]->dims[1];
for (int o = 0; o < outer; o++) {
cpuPointers[cur * 3 + 0] = (uint8_t*)(qk[b] + o * channels);
cpuPointers[cur * 3 + 1] = (uint8_t*)(qk[b] + o * channels);
cpuPointers[cur * 3 + 2] = (uint8_t*)((size_t)channels);
cur++;
}
}
cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice);
FastllmSoftmaxKernelBatchInner1 <256> <<<total, 256>>> (pointers);
FastllmCudaFree(pointers);
delete[] cpuPointers;
}
if (true) {
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
for (int b = 0; b < batch; b++) {
for (int i = 0; i < k0; i++) {
cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(float);
cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) v[b]->cudaData + i * v[b]->strides[0] * sizeof(float);
cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) output[b]->cudaData + i * group * q[b]->dims[1] * v[b]->dims[2] * sizeof(float);
cpuPointers[(b * k0 + i) * 8 + 3] = (uint8_t *) (size_t) (group * q[b]->dims[1]);
cpuPointers[(b * k0 + i) * 8 + 4] = (uint8_t *) (size_t) k[b]->dims[1];
cpuPointers[(b * k0 + i) * 8 + 5] = (uint8_t *) (size_t) v[b]->dims[2];
cpuPointers[(b * k0 + i) * 8 + 6] = (uint8_t *) (size_t) k[b]->dims[1];
cpuPointers[(b * k0 + i) * 8 + 7] = (uint8_t *) (size_t) v[b]->strides[1];
}
}
cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * k0 * 8, cudaMemcpyHostToDevice);
FastllmMatMulKernel <128> <<<batch * k0, 128>>> (pointers, 1.0f);
FastllmCudaFree(pointers);
delete[] cpuPointers;
}
FastllmCudaFree(mem);
delete[] qk;
/*
{
const int params = 16;
float **pointers = (float **) FastllmCudaMalloc(sizeof(float *) * batch * params);
float **cpuPointers = new float *[batch * params];
float **qk = new float *[batch];
float **temp = new float *[batch];
for (int b = 0; b < batch; b++) {
qk[b] = (float *) FastllmCudaMalloc(q[b]->dims[0] * k[b]->dims[1] * sizeof(float));
temp[b] = (float *) FastllmCudaMalloc(q[b]->dims[0] * k[b]->dims[1] * sizeof(float));
cpuPointers[b * params + 0] = (float *) q[b]->cudaData;
cpuPointers[b * params + 1] = (float *) k[b]->cudaData;
cpuPointers[b * params + 2] = (float *) v[b]->cudaData;
cpuPointers[b * params + 3] = (mask[b] && mask[b]->dims.size() > 0) ? (float *) mask[b]->cudaData : nullptr;
cpuPointers[b * params + 4] = (float *) output[b]->cudaData;
cpuPointers[b * params + 5] = (float *) (unsigned long long) q[b]->dims[1];
cpuPointers[b * params + 6] = (float *) (unsigned long long) q[b]->dims[2];
cpuPointers[b * params + 7] = (float *) (unsigned long long) k[b]->dims[1];
cpuPointers[b * params + 8] = (float *) (unsigned long long) v[b]->dims[2];
cpuPointers[b * params + 9] = (float *) (unsigned long long) q[b]->strides[0];
cpuPointers[b * params + 10] = (float *) (unsigned long long) k[b]->strides[0];
cpuPointers[b * params + 11] = (float *) (unsigned long long) v[b]->strides[0];
cpuPointers[b * params + 12] = (float *) (unsigned long long) output[b]->strides[0];
cpuPointers[b * params + 13] = (float *) (unsigned long long) qk[b];
cpuPointers[b * params + 14] = (float *) (unsigned long long) temp[b];
cpuPointers[b * params + 15] = (float *) (unsigned long long) q[b]->dims[0];
}
cudaMemcpy(pointers, cpuPointers, sizeof(float *) * batch * params, cudaMemcpyHostToDevice);
FastllmAttentionBatchKernel<256> <<< batch, 256 >>>(pointers, scale, group);
for (int i = 0; i < batch; i++) {
FastllmCudaFree(qk[i]);
FastllmCudaFree(temp[i]);
}
delete[] qk;
delete[] temp;
FastllmCudaFree(pointers);
delete[] cpuPointers;
}
*/
/*
for (int b = 0; b < batch; b++) {
int q0 = q[b]->dims[0], q1 = q[b]->dims[1], q2 = q[b]->dims[2], k0 = k[b]->dims[0], k1 = k[b]->dims[1], v2 = v[b]->dims[2];
float *qd = (float *) q[b]->cudaData;
float *kd = (float *) k[b]->cudaData;
float *vd = (float *) v[b]->cudaData;
float *maskd = (mask[b] && mask[b]->dims.size() > 0) ? (float *) mask[b]->cudaData : nullptr;
float *od = (float *) output[b]->cudaData;
int maskBatch = (mask[b] && mask[b]->dims.size() > 0) ? mask[b]->dims[0] : 1;
float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
FastllmAttentionKernel<256> <<<q0, 256>>>(qd, kd, vd, maskd, od,
scale, q1, q2, k1, v2,
group, q[b]->strides[0], k[b]->strides[0], v[b]->strides[0],
output[b]->strides[0],
qk, temp);
}
*/
DeviceSync();
return true;
}
bool FastllmCudaSplitBatch(fastllm::Data &input, fastllm::Data **outputs, int axis) {
int part = input.dims[axis];
int outer = input.Count(0) / input.Count(axis);
......
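For reference, the batched attention path added above (q·kᵀ, masked softmax, then multiplication by v) computes, per request and head group, the same result as this plain CPU sketch. Row-major shapes are assumed: q is q1×d, k is k1×d, v is k1×d2; a mask value above 0.99 marks a blocked position, mirroring the kernels. This is reference-only code, not fastllm's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference-only: out = softmax(scale * q.k^T, masked) . v for one head.
std::vector<float> AttentionRef(const std::vector<float> &q, const std::vector<float> &k,
                                const std::vector<float> &v, const std::vector<float> &mask,
                                int q1, int k1, int d, int d2, float scale) {
    std::vector<float> out(q1 * d2, 0.0f), score(k1);
    for (int i = 0; i < q1; i++) {
        float maxv = -1e30f, sum = 0.0f;
        for (int j = 0; j < k1; j++) {
            float s = 0.0f;
            for (int l = 0; l < d; l++) s += q[i * d + l] * k[j * d + l];
            score[j] = (!mask.empty() && mask[i * k1 + j] > 0.99f) ? -10000.0f : s * scale;
            maxv = std::max(maxv, score[j]);
        }
        for (int j = 0; j < k1; j++) { score[j] = std::exp(score[j] - maxv); sum += score[j]; }
        for (int j = 0; j < k1; j++) {
            for (int l = 0; l < d2; l++) out[i * d2 + l] += (score[j] / sum) * v[j * d2 + l];
        }
    }
    return out;
}
```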
...@@ -69,10 +69,10 @@ namespace fastllm {
if (intParams.find(it.first + "___batch") != intParams.end()) {
int batch = intParams.find(it.first + "___batch")->second;
for (int i = 0; i < batch; i++) {
lockInCPU |= ((Data**)it.second)[i]->lockInCPU;
lockInCPU |= (((Data**)it.second)[i] && ((Data**)it.second)[i]->lockInCPU);
}
} else {
lockInCPU |= it.second->lockInCPU;
lockInCPU |= (it.second && it.second->lockInCPU);
}
}
for (auto device: devices) {
...@@ -89,10 +89,14 @@ namespace fastllm {
if (intParams.find(it.first + "___batch") != intParams.end()) {
int batch = intParams.find(it.first + "___batch")->second;
for (int i = 0; i < batch; i++) {
((Data**)it.second)[i]->ToDevice((void *) device);
if (((Data**)it.second)[i]) {
((Data**)it.second)[i]->ToDevice((void *) device);
}
}
} else {
it.second->ToDevice((void *) device);
if (it.second) {
it.second->ToDevice((void *) device);
}
}
}
device->Reshape(opType, datas, floatParams, intParams);
......
...@@ -368,7 +368,11 @@ namespace fastllm {
this->cpuData = new uint8_t[this->expansionBytes];
} else if (this->dataDevice == DataDevice::CUDA) {
#ifdef USE_CUDA
this->cudaData = FastllmCudaMalloc(this->expansionBytes);
if (this->directMemory) {
this->cudaData = FastllmCudaDirectMalloc(this->expansionBytes);
} else {
this->cudaData = FastllmCudaMalloc(this->expansionBytes);
}
#else
ErrorInFastLLM("Error: cuda is not supported.\n");
#endif
...@@ -382,7 +386,11 @@ namespace fastllm {
delete[] this->cpuData;
} else if (this->dataDevice == DataDevice::CUDA) {
#ifdef USE_CUDA
FastllmCudaFree(this->cudaData);
if (this->directMemory) {
FastllmCudaDirectFree(this->cudaData);
} else {
FastllmCudaFree(this->cudaData);
}
#else
ErrorInFastLLM("Error: cuda is not supported.\n");
#endif
...@@ -415,6 +423,7 @@ namespace fastllm {
void Data::Expansion(const std::vector<int> &dims) {
if (this->dims.size() == 0) {
this->directMemory = true;
this->strides.resize(dims.size(), 1);
this->strides.back() = 1;
for (int i = dims.size() - 2; i >= 0; i--) {
...@@ -489,6 +498,11 @@ namespace fastllm {
#ifdef USE_CUDA
if (this->cudaData != nullptr) {
FastllmCudaFree(this->cudaData);
/*if (this->directMemory) {
FastllmCudaDirectFree(this->cudaData);
} else {
FastllmCudaFree(this->cudaData);
}*/
}
#endif
}
...@@ -524,6 +538,10 @@ namespace fastllm {
}
printf("\n");
*/
// // To print data that lives in CUDA memory, move it to the CPU first. xzhou 20230728
// if (dataDevice == DataDevice::CUDA) {
// ToDevice(DataDevice::CPU);
// }
int n = Count(0) / dims.back(), m = dims.back();
for (int i = 0; i < n; i++) {
for (int j = 0; j < 10 && j < m; j++) {
...@@ -548,7 +566,7 @@ namespace fastllm {
weightSum.resize(n);
for (int i = 0; i < n; i++) {
int j = 0;
#ifdef __AVX__
#ifdef __AVX2__
__m256i acc = _mm256_setzero_si256();
const __m256i ones = _mm256_set1_epi16(1);
for (; j + 31 < m; j += 32) {
...@@ -594,7 +612,7 @@ namespace fastllm {
}
weightSum[i] += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#endif
#ifdef __AVX__
#ifdef __AVX2__
__m256i acc = _mm256_setzero_si256();
const __m256i lowMask = _mm256_set1_epi8(0xf);
const __m256i ones = _mm256_set1_epi16(1);
...@@ -795,6 +813,18 @@ namespace fastllm {
q.push(SymbolPairs(now->score, l, r, symbols[l].len + symbols[r].len));
}
int Tokenizer::GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip) {
if (idx + skip + 2 >= partitions.size()) {
return std::numeric_limits<int>::max();
}
auto s = symbols[0].s + symbols[0].pos;
std::string key(s + partitions[idx].first, s + partitions[idx + skip + 2].first);
if (stringToTokenDict.find(key) != stringToTokenDict.end()) {
return stringToTokenDict[key];
}
return std::numeric_limits<int>::max();
}
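GetRank() above returns the rank (token id) of the string obtained by merging adjacent partitions, and the rewritten Encode() below repeatedly merges the adjacent pair with the lowest rank until no further merge is possible. Below is a standalone sketch of that greedy strategy, with `ranks` standing in for stringToTokenDict and under the assumption that every single character is itself a known token; the names are illustrative, not fastllm's.

```cpp
#include <limits>
#include <map>
#include <string>
#include <vector>

// Sketch only: split points start at every character; each round merges the adjacent pair
// whose merged substring has the lowest rank, then the surviving pieces are looked up.
std::vector<int> GreedyBpe(const std::string &s, const std::map<std::string, int> &ranks) {
    std::vector<int> cut;
    for (int i = 0; i <= (int) s.size(); i++) cut.push_back(i);
    while (cut.size() > 2) {
        int best = std::numeric_limits<int>::max(), bestIdx = -1;
        for (int i = 0; i + 2 < (int) cut.size(); i++) {
            auto it = ranks.find(s.substr(cut[i], cut[i + 2] - cut[i]));
            if (it != ranks.end() && it->second < best) { best = it->second; bestIdx = i; }
        }
        if (bestIdx < 0) break;                    // no adjacent pair forms a known token
        cut.erase(cut.begin() + bestIdx + 1);      // merge the two neighbouring pieces
    }
    std::vector<int> tokens;
    for (int i = 0; i + 1 < (int) cut.size(); i++) {
        tokens.push_back(ranks.at(s.substr(cut[i], cut[i + 1] - cut[i])));
    }
    return tokens;
}
```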
Data Tokenizer::Encode(const std::string &ori) {
if (this->type == TokenizerType::BPE) {
std::string blank = "";
...@@ -926,48 +956,38 @@ namespace fastllm {
if (i == sep.back().first) {
if (!symbols.empty()) {
symbols.back().next = -1;
std::priority_queue<SymbolPairs> workQueue;
for (int i = 1; i < symbols.size(); i++) {
TryMergePairs(symbols, i - 1, i, workQueue);
std::string cur = ori.substr(i - symbols.size(), symbols.size());
std::vector<std::pair<int, int>> partitions(symbols.size() + 1);
for (int j = 0; j <= (int) symbols.size(); j++) {
partitions[j] = std::make_pair(j, std::numeric_limits<int>::max());
} }
for (int j = 0; j < partitions.size() - 2; j++) {
while (!workQueue.empty()) {
partitions[j].second = GetRank(symbols, partitions, j, 0);
auto top = workQueue.top();
workQueue.pop();
if (symbols[top.l].len == 0 || symbols[top.r].len == 0 ||
symbols[top.l].len + symbols[top.r].len != top.size) {
continue;
}
for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) {
symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]];
}
symbols[top.l].len += symbols[top.r].len;
symbols[top.r].len = 0;
symbols[top.l].next = symbols[top.r].next;
if (symbols[top.r].next >= 0) {
symbols[symbols[top.r].next].prev = top.l;
}
TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue);
TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
} }
while (partitions.size() > 1) {
for (int i = 0; i < symbols.size(); i++) { int min_rank = std::numeric_limits<int>::max();
if (symbols[i].len > 0) { int min_rank_idx = 0;
v.push_back(symbols[i].node->tokenId); for (int j = 0; j < partitions.size() - 1; ++j) {
} else if (symbols[i].node == nullptr) { if (partitions[j].second < min_rank) {
// unrecognized character
uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]); min_rank_idx = j;
std::string now = "<0x00>";
now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16));
now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16));
if (stringToTokenDict.find(now) != stringToTokenDict.end()) {
v.push_back(stringToTokenDict[now]);
} }
} }
if (min_rank != std::numeric_limits<int>::max()) {
partitions[min_rank_idx].second = GetRank(symbols, partitions, min_rank_idx, 1);
if (min_rank_idx > 0) {
partitions[min_rank_idx - 1].second = GetRank(symbols, partitions, min_rank_idx - 1, 1);
}
partitions.erase(partitions.begin() + min_rank_idx + 1);
} else {
break;
}
} }
symbols.clear(); symbols.clear();
for (int j = 0; j < partitions.size() - 1; j++) {
std::string key = cur.substr(partitions[j].first, partitions[j + 1].first - partitions[j].first);
v.push_back((float) stringToTokenDict[key]);
}
} }
std::string special = ori.substr(sep.back().first, sep.back().second); std::string special = ori.substr(sep.back().first, sep.back().second);
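The rewritten loop above replaces the old priority-queue merging with a partition-based greedy pass: every boundary carries the rank of merging the two pieces to its right, each round removes the boundary with the smallest rank and refreshes its two neighbours, and the pass stops once no adjacent pair is in the vocabulary; the surviving pieces are then looked up directly. A compact sketch of the same loop over a plain string and a toy vocabulary (function and variable names are assumptions, and it assumes every surviving piece exists in the vocabulary):

#include <limits>
#include <map>
#include <string>
#include <vector>

// Greedy byte-pair merge over boundary ranks, mirroring the partitions loop above.
// `vocab` maps merged strings to ids; a lower id is treated as a better merge.
std::vector<int> ToyBpe(const std::string &s, const std::map<std::string, int> &vocab) {
    const int INF = std::numeric_limits<int>::max();
    // parts[i].first = start offset of piece i; one extra boundary at s.size()
    std::vector<std::pair<int, int>> parts(s.size() + 1);
    for (int i = 0; i <= (int) s.size(); i++) parts[i] = {i, INF};
    auto rank = [&](int idx, int skip) {
        if (idx + skip + 2 >= (int) parts.size()) return INF;
        std::string key = s.substr(parts[idx].first, parts[idx + skip + 2].first - parts[idx].first);
        auto it = vocab.find(key);
        return it == vocab.end() ? INF : it->second;
    };
    for (int i = 0; i + 2 < (int) parts.size(); i++) parts[i].second = rank(i, 0);
    while (parts.size() > 1) {
        int best = INF, bestIdx = 0;
        for (int i = 0; i + 1 < (int) parts.size(); i++) {
            if (parts[i].second < best) { best = parts[i].second; bestIdx = i; }
        }
        if (best == INF) break;                       // nothing left to merge
        parts[bestIdx].second = rank(bestIdx, 1);     // rank after absorbing the next piece
        if (bestIdx > 0) parts[bestIdx - 1].second = rank(bestIdx - 1, 1);
        parts.erase(parts.begin() + bestIdx + 1);     // drop the boundary that was merged away
    }
    std::vector<int> ids;
    for (int i = 0; i + 1 < (int) parts.size(); i++) {
        ids.push_back(vocab.at(s.substr(parts[i].first, parts[i + 1].first - parts[i].first)));
    }
    return ids;
}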
...@@ -1592,6 +1612,14 @@ namespace fastllm { ...@@ -1592,6 +1612,14 @@ namespace fastllm {
} }
} }
void CopyKVCache(Data &oldCache, Data &newCache, int oldBsStart, int newBsStart, int bs, int offset) {
curExecutor->Run("CopyKVCache", {
{"oldCache", (Data*)&oldCache}, {"newCache", (Data*)&newCache}
}, {}, {
{"oldBsStart", oldBsStart}, {"newBsStart", newBsStart}, {"bs", bs}, {"offset", offset}
});
}
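CopyKVCache is a thin wrapper that forwards the copy to the active executor as a named op, parameterised by a batch window (oldBsStart, newBsStart, bs) and an offset. The device kernel itself is not part of this diff, so the following is only a plain-CPU reading of those parameters; the [batch, len, dim] float layout and the interpretation of offset as a token offset into the destination are both assumptions:

#include <cstring>

// Reference sketch of the assumed semantics: copy `bs` sequences from the old
// cache into a different batch window of the new cache.
void CopyKVCacheReference(const float *oldCache, float *newCache,
                          int oldBsStart, int newBsStart, int bs,
                          int oldLen, int newLen, int dim, int offset) {
    for (int b = 0; b < bs; b++) {
        const float *src = oldCache + (size_t) (oldBsStart + b) * oldLen * dim;
        float *dst = newCache + (size_t) (newBsStart + b) * newLen * dim + (size_t) offset * dim;
        std::memcpy(dst, src, sizeof(float) * (size_t) oldLen * dim);
    }
}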
void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output, void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
int group, float scale, int attentionType) { int group, float scale, int attentionType) {
curExecutor->Run("Attention", { curExecutor->Run("Attention", {
...@@ -1814,6 +1842,21 @@ namespace fastllm { ...@@ -1814,6 +1842,21 @@ namespace fastllm {
}, {}, {{"axis", axis}, {"input0___batch", (int)input0.size()}, {"input1___batch", (int)input1.size()}}); }, {}, {{"axis", axis}, {"input0___batch", (int)input0.size()}, {"input1___batch", (int)input1.size()}});
} }
void AttentionBatch(std::vector <Data*> &q, std::vector <Data*> &k, std::vector <Data*> &v,
std::vector <Data*> &mask, std::vector <Data*> &output,
int group, float scale, int attentionType) {
curExecutor->Run("AttentionBatch", {
{"q", (Data*)q.data()}, {"k", (Data*)k.data()}, {"v", (Data*)v.data()},
{"mask", (Data*)mask.data()}, {"output", (Data*)output.data()}
},
{{"scale", scale}},
{
{"group", group},
{"q___batch", (int)q.size()}, {"k___batch", (int)k.size()}, {"v___batch", (int)v.size()},
{"mask___batch", (int)mask.size()}, {"output___batch", (int)output.size()}
});
}
void LoraLayer(Data &input, Data &weight, Data &loraA, Data &loraB, const Data &bias, Data &output, void LoraLayer(Data &input, Data &weight, Data &loraA, Data &loraB, const Data &bias, Data &output,
std::map <std::string, std::string> loraConfig) { std::map <std::string, std::string> loraConfig) {
float r = std::atof(loraConfig["r"].c_str()); float r = std::atof(loraConfig["r"].c_str());
......
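AttentionBatch hands a list of per-sequence q/k/v/mask tensors to one batched attention op; scale multiplies q·kᵀ, and group is presumably the ratio of query heads to key/value heads (the call site elsewhere in this diff passes qs[0]->dims[0] / values[0]->dims[0]). As a reference for the per-sequence math it batches, here is a naive single-head scaled-dot-product attention (a sketch, not the fastllm kernel; the mask is omitted for brevity):

#include <algorithm>
#include <cmath>
#include <vector>

// Naive attention for one head. Shapes: q[qlen][dim], k[klen][dim], v[klen][dim], out[qlen][dim].
void NaiveAttention(const std::vector<std::vector<float>> &q,
                    const std::vector<std::vector<float>> &k,
                    const std::vector<std::vector<float>> &v,
                    std::vector<std::vector<float>> &out, float scale) {
    int qlen = (int) q.size(), klen = (int) k.size(), dim = (int) q[0].size();
    out.assign(qlen, std::vector<float>(dim, 0.0f));
    for (int i = 0; i < qlen; i++) {
        std::vector<float> score(klen);
        float maxScore = -1e30f, sum = 0.0f;
        for (int j = 0; j < klen; j++) {
            float s = 0.0f;
            for (int d = 0; d < dim; d++) s += q[i][d] * k[j][d];   // q * k^T
            score[j] = s * scale;
            maxScore = std::max(maxScore, score[j]);
        }
        for (int j = 0; j < klen; j++) {                            // softmax over the key axis
            score[j] = std::exp(score[j] - maxScore);
            sum += score[j];
        }
        for (int j = 0; j < klen; j++) {                            // prob * v
            for (int d = 0; d < dim; d++) out[i][d] += (score[j] / sum) * v[j][d];
        }
    }
}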
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "basellm.h" #include "basellm.h"
#include "utils.h" #include "utils.h"
#include <sstream> #include <sstream>
#include <cstring>
#ifdef USE_CUDA #ifdef USE_CUDA
#include "fastllm-cuda.cuh" #include "fastllm-cuda.cuh"
...@@ -339,10 +340,24 @@ namespace fastllm { ...@@ -339,10 +340,24 @@ namespace fastllm {
LastTokensManager tokensManager; LastTokensManager tokensManager;
std::vector <std::vector <float>* > logits; std::vector <std::vector <float>* > logits;
model->dictLocker.lock(); model->dictLocker.lock();
int limit = model->tokensLimit > 0 ? model->tokensLimit : 1e9;
int lenSum = 0;
for (auto &it: model->responseContextDict.dicts) {
if (it.second->pastKeyValues[0].first.expansionDims.size() > 0 && !it.second->isEnding) {
lenSum += it.second->pastKeyValues[0].first.expansionDims[1];
}
}
for (int isPrompt = 1; isPrompt >= 0; isPrompt--) { for (int isPrompt = 1; isPrompt >= 0; isPrompt--) {
int cnt = 0;
if (isPrompt == 0 && seqLens.size() > 0) { if (isPrompt == 0 && seqLens.size() > 0) {
continue; continue;
} }
if (lenSum > limit && isPrompt) {
continue;
}
for (auto &it: model->responseContextDict.dicts) { for (auto &it: model->responseContextDict.dicts) {
if (it.second->isEnding) { if (it.second->isEnding) {
continue; continue;
...@@ -350,6 +365,16 @@ namespace fastllm { ...@@ -350,6 +365,16 @@ namespace fastllm {
if (isPrompt && it.second->preTokens != 0) { if (isPrompt && it.second->preTokens != 0) {
continue; continue;
} }
if (!isPrompt && it.second->preTokens == 0) {
continue;
}
int outputLimit = it.second->generationConfig.output_token_limit;
outputLimit = (outputLimit < 0 ? 128 : outputLimit);
if (isPrompt && lenSum + it.second->currentTokens.size() + outputLimit > limit) {
continue;
}
generationConfigs.push_back(it.second->generationConfig); generationConfigs.push_back(it.second->generationConfig);
if (it.second->generationConfig.output_logits) { if (it.second->generationConfig.output_logits) {
it.second->resultLogits.push(new std::vector<float>()); it.second->resultLogits.push(new std::vector<float>());
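With tokensLimit set, the scheduler first sums the KV-cache length already reserved by live requests (expansionDims[1] per sequence) and then only admits a new prompt when its own tokens plus its output budget (output_token_limit, defaulting to 128 when negative) still fit under the limit. A compact sketch of that admission rule (names are illustrative):

// Admission check sketch: can this prompt join the running batch without
// exceeding the configured token capacity? (tokensLimit <= 0 means "no limit".)
bool CanAdmitPrompt(int lenSum, int promptTokens, int outputTokenLimit, int tokensLimit) {
    int limit = tokensLimit > 0 ? tokensLimit : 1000000000;   // mirrors the 1e9 fallback above
    int outputBudget = outputTokenLimit < 0 ? 128 : outputTokenLimit;
    return lenSum + promptTokens + outputBudget <= limit;
}

Only new prompts are gated this way; decode steps (isPrompt == 0) are not held back, so requests that are already running can keep producing tokens.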
...@@ -397,6 +422,7 @@ namespace fastllm { ...@@ -397,6 +422,7 @@ namespace fastllm {
&it.second->pastKeyValues[i].second)); &it.second->pastKeyValues[i].second));
} }
if (isPrompt) { if (isPrompt) {
cnt += it.second->currentTokens.size();
break; break;
} }
} }
...@@ -412,6 +438,8 @@ namespace fastllm { ...@@ -412,6 +438,8 @@ namespace fastllm {
#endif #endif
Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids); Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids);
std::vector<int> ret; std::vector<int> ret;
auto st = std::chrono::system_clock::now();
//ClearProfiler();
if (seqLens.size() > 1) { if (seqLens.size() > 1) {
ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks,
positionIds, seqLens, pastKeyValues, generationConfigs, positionIds, seqLens, pastKeyValues, generationConfigs,
...@@ -422,7 +450,13 @@ namespace fastllm { ...@@ -422,7 +450,13 @@ namespace fastllm {
*positionIds[0], *positionIds[0],
*pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])};
} }
//PrintProfiler();
/*
static int tot = 0;
printf("len = %d, spend = %f s.\n", (int)seqLens.size(), GetSpan(st, std::chrono::system_clock::now()));
tot += (int)seqLens.size();
printf("tot = %d\n", tot);
*/
model->dictLocker.lock(); model->dictLocker.lock();
for (int i = 0; i < handles.size(); i++) { for (int i = 0; i < handles.size(); i++) {
auto &it = *model->responseContextDict.dicts.find(handles[i]); auto &it = *model->responseContextDict.dicts.find(handles[i]);
......
...@@ -190,7 +190,7 @@ namespace fastllm { ...@@ -190,7 +190,7 @@ namespace fastllm {
if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) { if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) {
newDims = std::vector<int>{k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]}; newDims = std::vector<int>{k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]};
if (generationConfig.output_token_limit > 0) { if (generationConfig.output_token_limit > 0) {
newDims[1] = std::min(newDims[1], k.dims[1] + generationConfig.output_token_limit); newDims[1] = k.dims[1] + generationConfig.output_token_limit;
} }
} else { } else {
newDims = pastKey.dims; newDims = pastKey.dims;
...@@ -207,7 +207,7 @@ namespace fastllm { ...@@ -207,7 +207,7 @@ namespace fastllm {
if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) { if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) {
newDims = std::vector<int>{v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]}; newDims = std::vector<int>{v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]};
if (generationConfig.output_token_limit > 0) { if (generationConfig.output_token_limit > 0) {
newDims[1] = std::min(newDims[1], k.dims[1] + generationConfig.output_token_limit); newDims[1] = k.dims[1] + generationConfig.output_token_limit;
} }
} else { } else {
newDims = pastValue.dims; newDims = pastValue.dims;
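In both the key and the value branch, the pre-allocated sequence capacity changes from std::min(rounded, promptLen + output_token_limit) to promptLen + output_token_limit whenever an output limit is set; without a limit the length is still rounded up to a multiple of unitLen. A small sketch of the sizing rule (the helper name is an assumption):

// Sequence-axis capacity for a freshly created KV cache (sketch).
int PlanKvSeqCapacity(int promptLen, int outputTokenLimit, int unitLen) {
    if (outputTokenLimit > 0) {
        return promptLen + outputTokenLimit;                  // exact budget: prompt + max new tokens
    }
    return ((promptLen - 1) / unitLen + 1) * unitLen;         // round up to unitLen blocks
}

Reserving the full prompt-plus-output budget up front presumably means later CatDirect appends during decoding no longer need to re-expand the cache, which lines up with the expansionDims checks added elsewhere in this commit.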
...@@ -377,12 +377,12 @@ namespace fastllm { ...@@ -377,12 +377,12 @@ namespace fastllm {
CatDirect(*(Data*)positionIds[0], *(Data*)positionIds[i], 1); CatDirect(*(Data*)positionIds[0], *(Data*)positionIds[i], 1);
} }
} }
std::vector <Data*> keys, values, qs, attns, masks, contexts;
std::vector <Data*> keys, values, qs, attns, contexts;
keys.resize(batch); keys.resize(batch);
values.resize(batch); values.resize(batch);
qs.resize(batch); qs.resize(batch);
attns.resize(batch); attns.resize(batch);
masks.resize(batch);
contexts.resize(batch); contexts.resize(batch);
std::vector <Data*> pointersK, pointersV, pointersQ; std::vector <Data*> pointersK, pointersV, pointersQ;
...@@ -486,6 +486,10 @@ namespace fastllm { ...@@ -486,6 +486,10 @@ namespace fastllm {
auto &q = curQs[b], &k = curKs[b], &v = curVs[b]; auto &q = curQs[b], &k = curKs[b], &v = curVs[b];
Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt +
i].second; i].second;
if (pastKey.dims.size() > 0 && pastKey.dims[1] + k.dims[1] <= pastKey.expansionDims[1]) {
continue;
}
pastKey.ToDevice(DataDevice::CUDA); pastKey.ToDevice(DataDevice::CUDA);
pastValue.ToDevice(DataDevice::CUDA); pastValue.ToDevice(DataDevice::CUDA);
...@@ -533,64 +537,76 @@ namespace fastllm { ...@@ -533,64 +537,76 @@ namespace fastllm {
} }
CatDirectBatch(keys, pointersK, 1); CatDirectBatch(keys, pointersK, 1);
CatDirectBatch(values, pointersV, 1); CatDirectBatch(values, pointersV, 1);
for (int b = 0; b < batch; b++) {
auto &q = curQs[b];
Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
outputSizes[b] = {1, q.dims[0], q.dims[1], pastKey.dims[1]};
q.Reshape({pastKey.dims[0], -1, q.dims[2]});
}
// 1.2 Attention
// 1.2.0 q * k^T
if (all1 && batch > 1) { if (all1 && batch > 1) {
for (int b = 0; b < batch; b++) { for (int b = 0; b < batch; b++) {
qs[b] = (&curQs[b]); qs[b] = (&curQs[b]);
keys[b] = (pastKeyValues[b * block_cnt + i].first); keys[b] = (pastKeyValues[b * block_cnt + i].first);
attns[b] = (&attnProbs[b]); values[b] = (pastKeyValues[b * block_cnt + i].second);
masks[b] = attentionMask[b];
contexts[b] = (&curContextLayer[b]);
outputSizes[b] = {1, qs[b]->dims[0], qs[b]->dims[1], keys[b]->dims[1]};
} }
MatMulTransBBatch(qs, keys, attns, 1.0 / (scale_attn * (i + 1))); AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], 1.0 / scale_attn, 1);
} else { } else {
for (int b = 0; b < batch; b++) { for (int b = 0; b < batch; b++) {
auto &q = curQs[b]; auto &q = curQs[b];
Data &pastKey = *pastKeyValues[b * block_cnt + i].first; Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
MatMulTransB(q, pastKey, attnProbs[b], 1.0 / (scale_attn * (i + 1))); outputSizes[b] = {1, q.dims[0], q.dims[1], pastKey.dims[1]};
q.Reshape({pastKey.dims[0], -1, q.dims[2]});
} }
}
for (int b = 0; b < batch; b++) { // 1.2 Attention
attnProbs[b].Reshape(outputSizes[b]); // 1.2.0 q * k^T
// 1.2.1 Mask if (all1 && batch > 1) {
if (attentionMask[b] != nullptr) { for (int b = 0; b < batch; b++) {
AttentionMask(attnProbs[b], *attentionMask[b], -10000); qs[b] = (&curQs[b]);
keys[b] = (pastKeyValues[b * block_cnt + i].first);
attns[b] = (&attnProbs[b]);
}
MatMulTransBBatch(qs, keys, attns, 1.0 / (scale_attn * (i + 1)));
} else {
for (int b = 0; b < batch; b++) {
auto &q = curQs[b];
Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
MatMulTransB(q, pastKey, attnProbs[b], 1.0 / (scale_attn * (i + 1)));
}
} }
}
// 1.2.2 softmax
for (int i = 0; i < attnProbs.size(); i++) {
attns[i] = (&attnProbs[i]);
}
MulBatch(attns, i + 1, attns);
SoftmaxBatch(attns, attns, -1);
for (int b = 0; b < batch; b++) {
Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
outputSizes[b] = {1, num_attention_heads, -1, pastValue.dims[2]};
attnProbs[b].Reshape({pastValue.dims[0], -1, attnProbs[b].dims[3]});
}
// 1.2.3 prob * v
if (all1 && batch > 1) {
for (int b = 0; b < batch; b++) { for (int b = 0; b < batch; b++) {
attns[b] = (&attnProbs[b]); attnProbs[b].Reshape(outputSizes[b]);
values[b] = (pastKeyValues[b * block_cnt + i].second); // 1.2.1 Mask
contexts[b] = (&curContextLayer[b]); if (attentionMask[b] != nullptr) {
AttentionMask(attnProbs[b], *attentionMask[b], -10000);
}
} }
MatMulBatch(attns, values, contexts);
} else { // 1.2.2 softmax
for (int i = 0; i < attnProbs.size(); i++) {
attns[i] = (&attnProbs[i]);
}
MulBatch(attns, i + 1, attns);
SoftmaxBatch(attns, attns, -1);
for (int b = 0; b < batch; b++) { for (int b = 0; b < batch; b++) {
Data &pastValue = *pastKeyValues[b * block_cnt + i].second; Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
MatMul(attnProbs[b], pastValue, curContextLayer[b]); outputSizes[b] = {1, num_attention_heads, -1, pastValue.dims[2]};
attnProbs[b].Reshape({pastValue.dims[0], -1, attnProbs[b].dims[3]});
}
// 1.2.3 prob * v
if (all1 && batch > 1) {
for (int b = 0; b < batch; b++) {
attns[b] = (&attnProbs[b]);
values[b] = (pastKeyValues[b * block_cnt + i].second);
contexts[b] = (&curContextLayer[b]);
}
MatMulBatch(attns, values, contexts);
} else {
for (int b = 0; b < batch; b++) {
Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
MatMul(attnProbs[b], pastValue, curContextLayer[b]);
}
} }
} }
if (all1) { if (all1) {
......
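The reshuffled block above picks one of two attention paths per layer. When every sequence in the batch is at the single-token decode stage (all1, presumably meaning every seqLen equals 1) and batch > 1, the whole q·kᵀ, mask, softmax, ·v chain is replaced by one fused AttentionBatch call over the cached keys and values; otherwise the explicit MatMulTransBBatch / AttentionMask / SoftmaxBatch / MatMulBatch sequence is kept. Note that the unfused path scales q·kᵀ by 1/(scale_attn·(i+1)) and then multiplies the scores by (i+1) via MulBatch before the softmax, so the effective scale is the same 1/scale_attn that the fused call passes directly. A sketch of the dispatch decision (the helper name is an assumption):

#include <vector>

// Use the fused batched attention kernel only when every request in the batch is
// decoding a single token; otherwise take the explicit matmul/softmax path.
bool UseFusedDecodeAttention(const std::vector<int> &seqLens, int batch) {
    if (batch <= 1) {
        return false;
    }
    for (int b = 0; b < batch; b++) {
        if (seqLens[b] != 1) {
            return false;
        }
    }
    return true;
}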