OpenDAS / fastllm — Commits

Commit 44be91d3, authored Oct 14, 2023 by zhouxiang
Sync features from the new upstream version; fix the issue of Qwen emitting unending output, among other changes.
Parent: aefd9f11
Showing 18 changed files with 805 additions and 133 deletions (+805 −133).
CMakeLists.txt                          +5    −2
docs/fastllm_pytools.md                 +0    −0
example/apiserver/apiserver.cpp         +57   −8
include/devices/cpu/cpudevice.h         +10   −0
include/devices/cuda/cudadevice.h       +10   −0
include/devices/cuda/fastllm-cuda.cuh   +4    −0
include/fastllm.h                       +10   −0
include/models/basellm.h                +2    −0
include/utils/utils.h                   +6    −0
src/devices/cpu/cpudevice.cpp           +56   −31
src/devices/cpu/cpudevicebatch.cpp      +41   −0
src/devices/cuda/cudadevice.cpp         +29   −1
src/devices/cuda/cudadevicebatch.cpp    +42   −0
src/devices/cuda/fastllm-cuda.cu        +346  −1
src/executor.cpp                        +8    −4
src/fastllm.cpp                         +83   −40
src/models/basellm.cpp                  +35   −1
src/models/chatglm.cpp                  +61   −45
CMakeLists.txt

@@ -18,7 +18,7 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
 elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNOMINMAX -O2 /std:c++17 /arch:AVX /source-charset:utf-8")
 else ()
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2 -g")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2 -march=native")
 endif ()
@@ -42,8 +42,9 @@ if (USE_CUDA)
     #message(${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
     set(FASTLLM_CUDA_SOURCES src/devices/cuda/cudadevice.cpp src/devices/cuda/cudadevicebatch.cpp src/devices/cuda/fastllm-cuda.cu)
     set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} cublas)
     set(CMAKE_CUDA_ARCHITECTURES "native")
     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g --gpu-max-threads-per-block=1024")
     #set(CMAKE_CUDA_ARCHITECTURES "70")
 endif ()
 if (PY_API)
@@ -84,6 +85,8 @@ add_custom_command(
 add_executable(benchmark example/benchmark/benchmark.cpp)
 target_link_libraries(benchmark fastllm)
+add_executable(apiserver example/apiserver/apiserver.cpp example/apiserver/json11.cpp)
+target_link_libraries(apiserver fastllm)
 add_library(fastllm_tools SHARED ${FASTLLM_CXX_SOURCES} ${FASTLLM_CUDA_SOURCES} tools/src/pytools.cpp)
 target_link_libraries(fastllm_tools PUBLIC ${FASTLLM_LINKED_LIBS})
docs/fastllm_pytools.md (new file, mode 100644)
example/apiserver/apiserver.cpp

@@ -128,6 +128,8 @@ struct APIConfig {
     int threads = 4;                 // 使用的线程数
     bool lowMemMode = false;         // 是否使用低内存模式
     int port = 8080;                 // 端口号
+    int tokens = -1;                 // token容量限制
+    int batch = 256;                 // batch数限制
 };

 void ToNext(char *&cur, const std::string &target, std::string &v) {
@@ -178,13 +180,40 @@ struct HttpRequest {
             }
         }

+        bool IsValid(char *buffer, int size) {
+            char *old = buffer;
+            headers.clear();
+            ToNext(buffer, " ", method);
+            ToNext(buffer, " ", route);
+            ToNext(buffer, "\r\n", type);
+            while (true) {
+                if (buffer[0] == 0 || ((long long)(buffer - old)) > 1024 * 1024) {
+                    break;
+                }
+                if (buffer[0] == '\r' && buffer[1] == '\n') {
+                    if (headers.find("Content-Length") != headers.end()) {
+                        if (size - ((long long)(buffer - old)) - 2 >= atoi(headers["Content-Length"].c_str())) {
+                            return true;
+                        } else {
+                            return false;
+                        }
+                    }
+                } else {
+                    std::string key;
+                    ToNext(buffer, ":", key);
+                    ToNext(buffer, "\r\n", headers[key]);
+                }
+            }
+            return false;
+        }
+
         void Print() {
             for (auto &it : headers) {
                 printf("%s: %s\n", it.first.c_str(), it.second.c_str());
             }
             printf("body: %s\n", body.c_str());
         }
-};
+} httpChecker;

 struct WorkNode {
     int client;
@@ -201,7 +230,7 @@ struct WorkNode {
 struct WorkQueue {
     std::unique_ptr<fastllm::basellm> model;
-    int maxActivateQueryNumber = 128;
+    int maxActivateQueryNumber = 256;
     int activateQueryNumber = 0;
     int totalQueryNumber = 0;
     std::mutex locker;
@@ -234,10 +263,12 @@ struct WorkQueue {
             WorkNode *now = ts->q.front();
             ts->q.pop();
             ts->activateQueryNumber++;
-            //ts->totalQueryNumber++;
-            //printf("totalQueryNumber = %d\n", ts->totalQueryNumber);
+            ts->totalQueryNumber++;
+            printf("totalQueryNumber = %d\n", ts->totalQueryNumber);
             //printf("activate = %d, q.size() = %d\n", ts->activateQueryNumber, (int) ts->q.size());
-            new std::thread([](WorkQueue *ts, WorkNode *now) {
+            std::thread *t = new std::thread([](WorkQueue *ts, WorkNode *now) {
                 ts->Deal(now);
                 printf("Response client %d finish\n", now->client);
                 ts->locker.lock();
@@ -310,11 +341,13 @@ void Usage() {
     std::cout << "<-w|--web> <args>: 网页文件的路径" << std::endl;
     std::cout << "<-t|--threads> <args>: 使用的线程数量" << std::endl;
     std::cout << "<-l|--low>: 使用低内存模式" << std::endl;
+    std::cout << "<--batch>: 最大batch数" << std::endl;
+    std::cout << "<--tokens>: 最大tokens容量" << std::endl;
     std::cout << "<--port> <args>: 网页端口号" << std::endl;
 }

 void ParseArgs(int argc, char **argv, APIConfig &config) {
     std::vector<std::string> sargv;
     for (int i = 0; i < argc; i++) {
         sargv.push_back(std::string(argv[i]));
     }
@@ -332,6 +365,10 @@ void ParseArgs(int argc, char **argv, APIConfig &config) {
             config.webPath = sargv[++i];
         } else if (sargv[i] == "--port") {
             config.port = atoi(sargv[++i].c_str());
+        } else if (sargv[i] == "--tokens") {
+            config.tokens = atoi(sargv[++i].c_str());
+        } else if (sargv[i] == "--batch") {
+            config.batch = atoi(sargv[++i].c_str());
         } else {
             Usage();
             exit(-1);
@@ -350,6 +387,8 @@ int main(int argc, char** argv) {
     fastllm::SetThreads(config.threads);
     fastllm::SetLowMemMode(config.lowMemMode);
     workQueue.model = fastllm::CreateLLMModelFromFile(config.path);
+    workQueue.model->tokensLimit = config.tokens;
+    workQueue.maxActivateQueryNumber = std::max(1, std::min(256, config.batch));
     workQueue.Start();
     int local_fd = socket(AF_INET, SOCK_STREAM, 0);
@@ -375,7 +414,6 @@ int main(int argc, char** argv) {
     listen(local_fd, 2000);
     printf("start...\n");
     int queuePos = 0;
     while (true) { //循环接收客户端的请求
         //5.创建一个sockaddr_in结构体,用来存储客户机的地址
         struct sockaddr_in client_addr;
@@ -386,8 +424,19 @@ int main(int argc, char** argv) {
             exit(-1);
         }
-        int size = read(client, buff, sizeof(buff));
+        int size = 0;
+        while (true) {
+            int cur = read(client, buff + size, sizeof(buff) - size);
+            size += cur;
+            if (httpChecker.IsValid(buff, size)) {
+                break;
+            }
+        }
         buff[size] = 0;
+        while (workQueue.q.size() > workQueue.maxActivateQueryNumber) {
+            sleep(0);
+        }
         workQueue.Push(buff, client);
     }
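The key server-side change above replaces a single read() with a loop that keeps reading until HttpRequest::IsValid confirms a complete request (request line, headers, and Content-Length bytes of body), which prevents truncated POST bodies from reaching the work queue. Below is a minimal standalone sketch of the same accumulate-until-complete pattern; recvSome and isComplete are hypothetical stand-ins for read(client, ...) and httpChecker.IsValid, and are not part of the commit.

    #include <functional>
    #include <string>

    // Sketch only: accumulate bytes until the checker reports a complete HTTP request.
    std::string ReadFullRequest(const std::function<int(char*, int)> &recvSome,
                                const std::function<bool(const char*, int)> &isComplete) {
        static const int kBufferSize = 1 << 20;          // 1 MB cap, mirroring the 1024*1024 guard in IsValid
        static char buff[kBufferSize];
        int size = 0;
        while (size < kBufferSize - 1) {
            int cur = recvSome(buff + size, kBufferSize - 1 - size);
            if (cur <= 0) {
                break;                                   // peer closed or error: stop accumulating
            }
            size += cur;
            if (isComplete(buff, size)) {
                break;                                   // full header + body received
            }
        }
        buff[size] = 0;
        return std::string(buff, size);
    }

Unlike this sketch, the code in the commit loops without checking for a zero or negative return from read(), so a client that disconnects early would keep the loop spinning; the sketch adds that guard as an assumption, not as something the commit does.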
include/devices/cpu/cpudevice.h

@@ -149,6 +149,11 @@ namespace fastllm {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };

+    class CpuCopyKVCacheOp : BaseOperator {
+        void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
     class CpuSplitBatchOp : BaseBatchOperator {
         void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
@@ -180,6 +185,11 @@ namespace fastllm {
     class CpuCatDirectBatchOp : BaseBatchOperator {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
+
+    class CpuAttentionBatchOp : BaseBatchOperator {
+        void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
 }

 #endif //FASTLLM_CPUDEVICE_H
include/devices/cuda/cudadevice.h

@@ -24,6 +24,11 @@ namespace fastllm {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };

+    class CudaCopyKVCacheOp : BaseOperator {
+        void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
     class CudaLayerNormOp : BaseOperator {
         bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
@@ -154,6 +159,11 @@ namespace fastllm {
     class CudaCatDirectBatchOp : BaseBatchOperator {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
+
+    class CudaAttentionBatchOp : BaseBatchOperator {
+        void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
 }

 #endif //FASTLLM_CUDADEVICE_H
include/devices/cuda/fastllm-cuda.cuh

@@ -9,6 +9,8 @@ void FastllmCudaMallocBigBuffer(size_t size);
 void FastllmCudaClearBigBuffer();
 void *FastllmCudaMalloc(size_t size);
 void FastllmCudaFree(void *ret);
+void *FastllmCudaDirectMalloc(size_t size);
+void FastllmCudaDirectFree(void *ret);

 void FastllmCudaCopyFromHostToDevice(void *dst, void *src, size_t size);
 void FastllmCudaCopyFromDeviceToHost(void *dst, void *src, size_t size);
@@ -55,6 +57,8 @@ bool FastllmCudaLlamaRotatePosition2D(fastllm::Data &data, const fastllm::Data &
                                       const fastllm::Data &sinData, const fastllm::Data &cosData, int rotaryDim);
 bool FastllmCudaApplyLognAttn (fastllm::Data &input, fastllm::Data &lognAttn, fastllm::Data &positionIds);
+bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v, fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch);
 bool FastllmCudaSplitBatch(fastllm::Data &input, fastllm::Data **outputs, int axis);
 bool FastllmCudaCatBatch(fastllm::Data **inputs, fastllm::Data &output, int axis);
 bool FastllmCudaMulBatch(fastllm::Data **inputs, float v, int batch, fastllm::Data **outputs);
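The two new declarations, FastllmCudaDirectMalloc and FastllmCudaDirectFree, bypass fastllm's pooled big-buffer allocator and go straight to cudaMalloc/cudaFree (their definitions appear later in fastllm-cuda.cu, and Data::Allocate selects them when directMemory is set). A hedged sketch of how a caller might choose between the two paths, assuming only the signatures declared above; AllocateDeviceBuffer and ReleaseDeviceBuffer are illustrative helpers, not functions from the commit.

    #include <cstddef>

    void *FastllmCudaMalloc(size_t size);        // pooled allocation (declared above)
    void FastllmCudaFree(void *ret);
    void *FastllmCudaDirectMalloc(size_t size);  // plain cudaMalloc, no pooling (declared above)
    void FastllmCudaDirectFree(void *ret);

    // Sketch: long-lived, size-stable buffers can use the pool; one-off or unusually
    // large buffers may prefer the direct path so they do not bloat the cached pool.
    void *AllocateDeviceBuffer(size_t bytes, bool longLived) {
        return longLived ? FastllmCudaMalloc(bytes) : FastllmCudaDirectMalloc(bytes);
    }

    void ReleaseDeviceBuffer(void *ptr, bool longLived) {
        if (longLived) { FastllmCudaFree(ptr); } else { FastllmCudaDirectFree(ptr); }
    }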
include/fastllm.h

@@ -247,6 +247,8 @@ namespace fastllm {
         long long filePos;
         std::shared_ptr<FileMmap> m_file;

+        bool directMemory = false; // 直接分配/释放Memory,不经过缓存
+
         Data () {};
         Data (DataType type);
@@ -364,6 +366,8 @@ namespace fastllm {
         void TryMergePairs(std::vector<Symbol> &symbols, int l, int r, std::priority_queue<SymbolPairs> &q); // 插入备选symbol
+        int GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip);
+
         void Insert(const std::string &s, int tokenId, float score = 1.0f); // 插入一个token
         Data Encode(const std::string &s); // 编码
@@ -418,9 +422,15 @@ namespace fastllm {
     void ToDataType(const Data &input, DataType dataType);

+    void CopyKVCache(Data &oldCache, Data &newCache, int oldBsStart, int newBsStart, int bs, int offset);
+
     void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
                    int group, float scale, int attentionType);
+    void AttentionBatch(std::vector <Data*> &q, std::vector <Data*> &k, std::vector <Data*> &v,
+                        std::vector <Data*> &mask, std::vector <Data*> &output, int group, float scale, int attentionType);
+
     void Embedding(const Data &input, Data &weight, Data &output);
     void RMSNorm(const Data &input, const Data &weight, float eps, Data &output);
include/models/basellm.h

@@ -152,5 +152,7 @@ namespace fastllm {
         std::map <std::string, int> deviceMap;
         std::string adapterName;
+
+        int tokensLimit = -1;
     };
 }
include/utils/utils.h

@@ -21,6 +21,12 @@
 #ifdef __AVX__
 #include "immintrin.h"
+#ifdef __GNUC__
+#if __GNUC__ < 8
+#define _mm256_set_m128i(/* __m128i */ hi, /* __m128i */ lo) \
+    _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 0x1)
+#endif
+#endif
 #endif

 namespace fastllm {
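The added block defines _mm256_set_m128i for GCC older than 8, which does not ship that intrinsic: the macro widens the low 128-bit half with a cast and inserts the high half into lane 1. A small self-contained check of the same expression is sketched below; it assumes an AVX-capable build (e.g. -mavx), and the variable names are illustrative only.

    #include <immintrin.h>
    #include <cstdio>

    #if defined(__GNUC__) && __GNUC__ < 8 && !defined(_mm256_set_m128i)
    #define _mm256_set_m128i(hi, lo) \
        _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 0x1)
    #endif

    int main() {
        __m128i lo = _mm_set1_epi32(1);            // low 128 bits: four 1s
        __m128i hi = _mm_set1_epi32(2);            // high 128 bits: four 2s
        __m256i v  = _mm256_set_m128i(hi, lo);     // combined 256-bit vector
        int out[8];
        _mm256_storeu_si256((__m256i *) out, v);
        printf("%d %d\n", out[0], out[7]);         // expected: 1 2
        return 0;
    }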
src/devices/cpu/cpudevice.cpp

@@ -23,6 +23,7 @@ namespace fastllm {
         this->ops["ToFloat16"] = (BaseOperator*)(new CpuToFloat16());
         this->ops["ToFloat32"] = (BaseOperator*)(new CpuToFloat32());
         this->ops["Attention"] = (BaseOperator*)(new CpuAttention());
+        this->ops["CopyKVCache"] = (BaseOperator*)(new CpuCopyKVCacheOp());
         this->ops["Embedding"] = (BaseOperator*)(new CpuEmbedding());
         this->ops["LayerNorm"] = (BaseOperator*)(new CpuLayerNormOp());
         this->ops["RMSNorm"] = (BaseOperator*)(new CpuRMSNormOp());
@@ -57,6 +58,7 @@ namespace fastllm {
         this->ops["MatMulTransBBatch"] = (BaseOperator*)(new CpuMatMulTransBBatchOp());
         this->ops["SoftMaxBatch"] = (BaseOperator*)(new CpuSoftmaxBatchOp());
         this->ops["CatDirectBatch"] = (BaseOperator*)(new CpuCatDirectBatchOp());
+        this->ops["AttentionBatch"] = (BaseOperator*)(new CpuAttentionBatchOp());
     }

     bool CpuDevice::Malloc(void **ret, size_t size) {
@@ -77,7 +79,7 @@ namespace fastllm {
         return true;
     }

-#ifdef __AVX__
+#ifdef __AVX2__
     int DotU8U8(uint8_t *a, uint8_t *b, int n) {
         __m256i acc = _mm256_setzero_si256();
@@ -105,32 +107,31 @@ namespace fastllm {
         return ans + I32sum(acc);
     };
-#else
-    int DotU8U8(uint8_t *a, uint8_t *b, int n) {
-        __m256i acc = _mm256_setzero_si256();
-        int i = 0;
-        int ans = 0;
-        for (; i + 31 < n; i += 32) {
-            __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
-            __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));
-            __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
-            __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));
-            __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
-            __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));
-            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
-            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
-        }
-        for (; i < n; i++) {
-            ans += a[i] * b[i];
-        }
-        return ans + I32sum(acc);
-    };
+//#else
+//    int DotU8U8(uint8_t *a, uint8_t *b, int n) {
+//        __m256i acc = _mm256_setzero_si256();
+//        int i = 0;
+//        int ans = 0;
+//        for (; i + 31 < n; i += 32) {
+//            __m256i bx = _mm256_loadu_si256((const __m256i *) (a + i));
+//            __m256i by = _mm256_loadu_si256((const __m256i *) (b + i));
+//            __m256i mx0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 0));
+//            __m256i mx1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(bx, 1));
+//            __m256i my0 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 0));
+//            __m256i my1 = _mm256_cvtepu8_epi16(_mm256_extractf128_si256(by, 1));
+//            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx0, my0));
+//            //acc = _mm256_add_epi32(acc, _mm256_madd_epi16(mx1, my1));
+//        }
+//        for (; i < n; i++) {
+//            ans += a[i] * b[i];
+//        }
+//        return ans + I32sum(acc);
+//    };
 #endif

     int DotU4U8(uint8_t *a, uint8_t *b, int n) {
         __m256i acc = _mm256_setzero_si256();
@@ -280,7 +281,7 @@ namespace fastllm {
         float *qd = (float *) q.cpuData;
         float *kd = (float *) k.cpuData;
         float *vd = (float *) v.cpuData;
-        float *maskd = mask.dims.size() > 0 ? (float *) mask.cpuData : nullptr;
+        float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float *) mask.cpuData : nullptr;
         float *od = (float *) output.cpuData;
         std::fill(od, od + output.Count(0), 0.0f);
         auto pool = GetPool();
@@ -296,6 +297,30 @@ namespace fastllm {
         }
     }

+    void CpuCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
+                                   const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        return;
+    }
+
+    void CpuCopyKVCacheOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                               const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &oldCache = *(datas.find("oldCache")->second);
+        Data &newCache = *(datas.find("newCache")->second);
+        int oldBsStart = intParams.find("oldBsStart") != intParams.end() ? intParams.find("oldBsStart")->second : -1;
+        int newBsStart = intParams.find("newBsStart") != intParams.end() ? intParams.find("newBsStart")->second : -1;
+        int bs = intParams.find("bs") != intParams.end() ? intParams.find("bs")->second : -1;
+        int offset = intParams.find("offset") != intParams.end() ? intParams.find("offset")->second : -1;
+        int unitSize = oldCache.unitSize;
+        for (int o = 0; o < bs; o++) {
+            uint8_t *cur = newCache.cpuData + (newBsStart + o) * newCache.strides[0] * unitSize;
+            cur += offset * newCache.strides[1] * unitSize;
+            uint8_t *old = oldCache.cpuData + (oldBsStart + o) * oldCache.strides[0] * unitSize;
+            memcpy(cur, old, oldCache.dims[1] * oldCache.dims[2] * unitSize);
+        }
+    }
+
     void CpuEmbedding::Reshape(const std::string &opType, const fastllm::DataDict &datas,
                                const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         Data &input = *(datas.find("input")->second);
@@ -894,7 +919,7 @@ namespace fastllm {
                 c[block * kstride + i] = value;
             }
         }
-#elif defined(__AVX__)
+#elif defined(__AVX2__)
         int block = 0;
         for (; block < n; block++) {
             uint8_t *weightWalk = b;
@@ -968,7 +993,7 @@ namespace fastllm {
                     sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
                 }
                 value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
-#elif defined(__AVX__)
+#elif defined(__AVX2__)
                 value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
                 j += m;
 #endif
@@ -1039,7 +1064,7 @@ namespace fastllm {
                     sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
                 }
                 value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
-#elif defined(__AVX__)
+#elif defined(__AVX2__)
                 value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
                 j += m;
 #endif
src/devices/cpu/cpudevicebatch.cpp

@@ -202,4 +202,45 @@ namespace fastllm {
             }
             delete op;
         }
     }
+
+    void CpuAttentionBatchOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
+                                      const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data **qs = (Data**)(datas.find("q")->second);
+        Data **ks = (Data**)(datas.find("k")->second);
+        Data **vs = (Data**)(datas.find("v")->second);
+        Data **outputs = (Data**)(datas.find("output")->second);
+        int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+        int batch = intParams.find("q___batch")->second;
+        Data &q = *qs[0], &k = *ks[0], &v = *vs[0];
+        AssertInFastLLM(q.dims.size() == 3 && k.dims.size() == 3 && v.dims.size() == 3, "Attention: dims of q, k, v should be 3.\n");
+        AssertInFastLLM(q.dims[2] == k.dims[2], "Attention: q.dims[2] should be equal to k.dims[2].\n");
+        AssertInFastLLM(k.dims[1] == v.dims[1], "Attention: k.dims[1] should be equal to v.dims[1].\n");
+        AssertInFastLLM(k.dims[0] == v.dims[0], "Attention: k.dims[0] should be equal to v.dims[0].\n");
+        AssertInFastLLM(q.dims[0] == k.dims[0] * group, "Attention: q.dims[0] should be equal to k.dims[0] * group.\n");
+        AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType, "Attention: q, k, v's datatype should be same.\n");
+        AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
+        for (int i = 0; i < batch; i++) {
+            outputs[i]->dataType = qs[i]->dataType;
+            outputs[i]->Resize({qs[i]->dims[0], qs[i]->dims[1], vs[i]->dims[2]});
+        }
+    }
+
+    void CpuAttentionBatchOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                                  const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        fastllm::BaseOperator *op = (fastllm::BaseOperator*)(new CpuAttention());
+        int batch = intParams.find("q___batch")->second;
+        DataDict tempDatas = datas;
+        for (int i = 0; i < batch; i++) {
+            tempDatas["q"] = ((Data**)datas.find("q")->second)[i];
+            tempDatas["k"] = ((Data**)datas.find("k")->second)[i];
+            tempDatas["v"] = ((Data**)datas.find("v")->second)[i];
+            tempDatas["mask"] = ((Data**)datas.find("mask")->second)[i];
+            tempDatas["output"] = ((Data**)datas.find("output")->second)[i];
+            op->Run("Attention", tempDatas, floatParams, intParams);
+        }
+        delete op;
+    }
 }
\ No newline at end of file
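The new CpuAttentionBatchOp follows fastllm's batch-operator convention visible above: a batched argument is a Data** array smuggled through the DataDict's Data* slot, and the element count travels in intParams under "<name>___batch". The sketch below only illustrates how a caller might assemble such a call; the placeholder types and the RunAttentionBatch helper are assumptions for illustration, not code from the commit.

    #include <map>
    #include <string>
    #include <vector>

    namespace sketch {
        struct Data {};                                   // placeholder for fastllm::Data
        typedef std::map<std::string, Data*>  DataDict;
        typedef std::map<std::string, int>    IntDict;
        typedef std::map<std::string, float>  FloatDict;

        void RunAttentionBatch(std::vector<Data*> &qs, std::vector<Data*> &ks, std::vector<Data*> &vs,
                               std::vector<Data*> &masks, std::vector<Data*> &outputs,
                               int group, float scale) {
            DataDict datas;
            datas["q"] = (Data*) qs.data();               // Data** stored in a Data* slot
            datas["k"] = (Data*) ks.data();
            datas["v"] = (Data*) vs.data();
            datas["mask"] = (Data*) masks.data();
            datas["output"] = (Data*) outputs.data();
            IntDict intParams = {{"q___batch", (int) qs.size()}, {"group", group}};
            FloatDict floatParams = {{"scale", scale}};
            // device->Run("AttentionBatch", datas, floatParams, intParams);  // dispatched as in executor.cpp
            (void) datas; (void) intParams; (void) floatParams;
        }
    }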
src/devices/cuda/cudadevice.cpp

@@ -13,6 +13,7 @@ namespace fastllm {
     CudaDevice::CudaDevice() {
         this->deviceType = "cuda";
         this->ops["Attention"] = (BaseOperator*)(new CudaAttention());
+        this->ops["CopyKVCache"] = (BaseOperator*)(new CudaCopyKVCacheOp());
         this->ops["LayerNorm"] = (BaseOperator*)(new CudaLayerNormOp());
         this->ops["RMSNorm"] = (BaseOperator*)(new CudaRMSNormOp());
         this->ops["Linear"] = (BaseOperator*)(new CudaLinearOp());
@@ -43,6 +44,7 @@ namespace fastllm {
         this->ops["MatMulTransBBatch"] = (BaseOperator*)(new CudaMatMulTransBBatchOp());
         this->ops["SoftMaxBatch"] = (BaseOperator*)(new CudaSoftmaxBatchOp());
         this->ops["CatDirectBatch"] = (BaseOperator*)(new CudaCatDirectBatchOp());
+        this->ops["AttentionBatch"] = (BaseOperator*)(new CudaAttentionBatchOp());
     }

     bool CudaDevice::Malloc(void **ret, size_t size) {
@@ -90,10 +92,11 @@ namespace fastllm {
     void CudaAttention::Run(const std::string &opType, const fastllm::DataDict &datas,
                             const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data emptyData;
         Data &q = *(datas.find("q")->second);
         Data &k = *(datas.find("k")->second);
         Data &v = *(datas.find("v")->second);
-        Data &mask = *(datas.find("mask")->second);
+        Data &mask = datas.find("mask")->second ? *(datas.find("mask")->second) : emptyData;
         Data &output = *(datas.find("output")->second);
         int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
         float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
@@ -101,6 +104,31 @@ namespace fastllm {
         FastllmCudaAttention(q, k, v, mask, output, group, scale);
     }

+    void CudaCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
+                                    const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        return;
+    }
+
+    void CudaCopyKVCacheOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                                const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &oldCache = *(datas.find("oldCache")->second);
+        Data &newCache = *(datas.find("newCache")->second);
+        int oldBsStart = intParams.find("oldBsStart") != intParams.end() ? intParams.find("oldBsStart")->second : -1;
+        int newBsStart = intParams.find("newBsStart") != intParams.end() ? intParams.find("newBsStart")->second : -1;
+        int bs = intParams.find("bs") != intParams.end() ? intParams.find("bs")->second : -1;
+        int offset = intParams.find("offset") != intParams.end() ? intParams.find("offset")->second : -1;
+        int unitSize = oldCache.unitSize;
+        FastllmCudaMemcpy2DDeviceToDevice((uint8_t *) newCache.cudaData + newBsStart * newCache.strides[0] * unitSize
+                                              + offset * newCache.strides[1] * unitSize,
+                                          newCache.strides[0] * unitSize,
+                                          (uint8_t *) oldCache.cudaData + oldBsStart * oldCache.strides[0] * unitSize,
+                                          oldCache.strides[0] * unitSize,
+                                          oldCache.dims[1] * oldCache.dims[2] * unitSize, bs);
+    }
+
     bool CudaRMSNormOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
                                const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         return true;
src/devices/cuda/cudadevicebatch.cpp

@@ -311,4 +311,46 @@ namespace fastllm {
         FastllmCudaMemcpy2DDeviceToDeviceBatch(dsts.data(), dpitchs.data(), srcs.data(),
                                                spitchs.data(), widths.data(), heights.data(), dsts.size());
     }
+
+    void CudaAttentionBatchOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
+                                       const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data **qs = (Data**)(datas.find("q")->second);
+        Data **ks = (Data**)(datas.find("k")->second);
+        Data **vs = (Data**)(datas.find("v")->second);
+        Data **outputs = (Data**)(datas.find("output")->second);
+        int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+        int batch = intParams.find("q___batch")->second;
+        Data &q = *qs[0], &k = *ks[0], &v = *vs[0];
+        AssertInFastLLM(q.dims.size() == 3 && k.dims.size() == 3 && v.dims.size() == 3, "Attention: dims of q, k, v should be 3.\n");
+        AssertInFastLLM(q.dims[2] == k.dims[2], "Attention: q.dims[2] should be equal to k.dims[2].\n");
+        AssertInFastLLM(k.dims[1] == v.dims[1], "Attention: k.dims[1] should be equal to v.dims[1].\n");
+        AssertInFastLLM(k.dims[0] == v.dims[0], "Attention: k.dims[0] should be equal to v.dims[0].\n");
+        AssertInFastLLM(q.dims[0] == k.dims[0] * group, "Attention: q.dims[0] should be equal to k.dims[0] * group.\n");
+        AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType, "Attention: q, k, v's datatype should be same.\n");
+        AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
+        for (int i = 0; i < batch; i++) {
+            outputs[i]->dataType = qs[i]->dataType;
+            outputs[i]->Resize({qs[i]->dims[0], qs[i]->dims[1], vs[i]->dims[2]});
+        }
+    }
+
+    void CudaAttentionBatchOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                                   const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        int batch = intParams.find("q___batch")->second;
+        Data **qs = (Data**)(datas.find("q")->second);
+        Data **ks = (Data**)(datas.find("k")->second);
+        Data **vs = (Data**)(datas.find("v")->second);
+        Data **masks = (Data**)(datas.find("mask")->second);
+        Data **outputs = (Data**)(datas.find("output")->second);
+        for (int i = 0; i < batch; i++) {
+            outputs[i]->Allocate();
+        }
+        FastllmCudaAttentionBatch(qs, ks, vs, masks, outputs,
+                                  intParams.find("group")->second, floatParams.find("scale")->second,
+                                  intParams.find("q___batch")->second);
+    }
 }
\ No newline at end of file
src/devices/cuda/fastllm-cuda.cu

@@ -800,6 +800,63 @@ __global__ void FastllmMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
     int input0Stride = (int)((size_t) pointer[id * 8 + 6]);
     int input1Stride = (int)((size_t) pointer[id * 8 + 7]);
+
+    int tid = threadIdx.x;
+    int pera = 4, perb = 4;
+    float cura[4][4], curb[4][4], curc[4][4];
+    int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
+    for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) {
+        int taska = taskId / cntb, taskb = taskId % cntb;
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 4; j++) {
+                cura[i][j] = 0;
+                curb[i][j] = 0;
+                curc[i][j] = 0;
+            }
+        }
+        for (int l = 0; l < m; l += 4) {
+            for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) {
+#pragma unroll
+                for (int x = 0; x < 4; x++) {
+                    cura[a - taska * pera][x] = input0[a * input0Stride + l + x];
+                }
+            }
+            for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) {
+#pragma unroll
+                for (int x = 0; x < 4; x++) {
+                    curb[b - taskb * perb][x] = input1[b * input1Stride + l + x];
+                }
+            }
+#pragma unroll
+            for (int i = 0; i < 4; i++) {
+#pragma unroll
+                for (int j = 0; j < 4; j++) {
+#pragma unroll
+                    for (int k = 0; k < 4; k++) {
+                        curc[i][j] += cura[i][k] * curb[j][k];
+                    }
+                }
+            }
+        }
+        if ((taska + 1) * pera <= n && (taskb + 1) * perb <= k) {
+#pragma unroll
+            for (int i = 0; i < 4; i++) {
+#pragma unroll
+                for (int j = 0; j < 4; j++) {
+                    output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
+                }
+            }
+        } else {
+            for (int i = 0; i < pera && taska * pera + i < n; i++) {
+                for (int j = 0; j < perb && taskb * perb + j < k; j++) {
+                    output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
+                }
+            }
+        }
+    }
+    /*
     int tid = threadIdx.x;
     for (int i = 0; i < n; i++) {
         float *curInput0 = input0 + i * input0Stride;
@@ -812,6 +869,7 @@ __global__ void FastllmMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
             output[i * k + j] = sum * alpha;
         }
     }
+    */
 }

 template <int THREAD_PER_BLOCK>
@@ -827,6 +885,64 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) {
     int input1Stride = (int)((size_t) pointer[id * 8 + 7]);
     int tid = threadIdx.x;
+    int pera = 4, perb = 4;
+    float cura[4][4], curb[4][4], curc[4][4];
+    int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
+    for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) {
+        int taska = taskId / cntb, taskb = taskId % cntb;
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 4; j++) {
+                cura[i][j] = 0;
+                curb[i][j] = 0;
+                curc[i][j] = 0;
+            }
+        }
+        for (int l = 0; l < m; l += 4) {
+            for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) {
+#pragma unroll
+                for (int x = 0; x < 4; x++) {
+                    cura[a - taska * pera][x] = l + x < m ? input0[a * input0Stride + l + x] : 0;
+                }
+            }
+            for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) {
+#pragma unroll
+                for (int x = 0; x < 4; x++) {
+                    curb[b - taskb * perb][x] = l + x < m ? input1[(l + x) * input1Stride + b] : 0;
+                }
+            }
+#pragma unroll
+            for (int i = 0; i < 4; i++) {
+#pragma unroll
+                for (int j = 0; j < 4; j++) {
+#pragma unroll
+                    for (int k = 0; k < 4; k++) {
+                        curc[i][j] += cura[i][k] * curb[j][k];
+                    }
+                }
+            }
+        }
+        if ((taska + 1) * pera <= n && (taskb + 1) * perb <= k) {
+#pragma unroll
+            for (int i = 0; i < 4; i++) {
+#pragma unroll
+                for (int j = 0; j < 4; j++) {
+                    output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
+                }
+            }
+        } else {
+            for (int i = 0; i < pera && taska * pera + i < n; i++) {
+                for (int j = 0; j < perb && taskb * perb + j < k; j++) {
+                    output[(taska * pera + i) * k + (taskb * perb + j)] = curc[i][j] * alpha;
+                }
+            }
+        }
+    }
+    /*
+    //int tid = threadIdx.x;
     for (int i = 0; i < n; i++) {
         float *curInput0 = input0 + i * input0Stride;
         for (int j = tid; j < k; j += THREAD_PER_BLOCK) {
@@ -838,6 +954,7 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) {
             output[i * k + j] = sum * alpha;
         }
     }
+    */
 }

 template <int THREAD_PER_BLOCK>
@@ -880,6 +997,71 @@ __global__ void FastllmAttentionKernel(float *qd, float *kd, float *vd, float *m
         }
     }

+template <int THREAD_PER_BLOCK>
+__global__ void FastllmAttentionBatchKernel(float** pointer, float scale, int group) {
+    const int params = 16;
+    int id = blockIdx.x;
+    float *qd = (float *) pointer[id * params + 0];
+    float *kd = (float *) pointer[id * params + 1];
+    float *vd = (float *) pointer[id * params + 2];
+    float *maskd = (float *) pointer[id * params + 3];
+    float *od = (float *) pointer[id * params + 4];
+    int q1 = (int)(unsigned long long) pointer[id * params + 5];
+    int q2 = (int)(unsigned long long) pointer[id * params + 6];
+    int k1 = (int)(unsigned long long) pointer[id * params + 7];
+    int v2 = (int)(unsigned long long) pointer[id * params + 8];
+    int qstride = (int)(unsigned long long) pointer[id * params + 9];
+    int kstride = (int)(unsigned long long) pointer[id * params + 10];
+    int vstride = (int)(unsigned long long) pointer[id * params + 11];
+    int ostride = (int)(unsigned long long) pointer[id * params + 12];
+    float *qk = (float *) pointer[id * params + 13];
+    float *temp = (float *) pointer[id * params + 14];
+    int q0 = (int)(unsigned long long) pointer[id * params + 15];
+    for (int o = 0; o < q0; o++) {
+        qd += o * qstride;
+        kd += (o / group) * kstride;
+        vd += (o / group) * vstride;
+        od += o * ostride;
+        qk += o * k1;
+        temp += o * k1;
+        for (int i = 0; i < q1; i++) {
+            for (int j = threadIdx.x; j < k1; j += THREAD_PER_BLOCK) {
+                if (maskd && maskd[i * k1 + j] > 0.99) {
+                    qk[j] = -10000;
+                    continue;
+                }
+                float sum = 0.0f;
+                float *tempQd = qd + i * q2, *tempKd = kd + j * q2;
+                for (int l = 0; l < q2; l++) {
+                    sum += tempQd[l] * tempKd[l];
+                }
+                qk[j] = sum * scale;
+            }
+            __syncthreads();
+            FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (qk, temp, k1);
+            __syncthreads();
+            for (int j = threadIdx.x; j < v2; j += THREAD_PER_BLOCK) {
+                float *curInput1 = vd + j;
+                float sum = 0.0;
+                for (int l = 0; l < k1; l++) {
+                    sum += temp[l] * curInput1[l * v2];
+                }
+                od[i * v2 + j] = sum;
+            }
+            __syncthreads();
+        }
+        qd -= o * qstride;
+        kd -= (o / group) * kstride;
+        vd -= (o / group) * vstride;
+        od -= o * ostride;
+        qk -= o * k1;
+        temp -= o * k1;
+    }
+}
+
 void *FastllmCudaPrepareInput(const fastllm::Data &input) {
     void *ret;
     if (input.dataDevice == fastllm::DataDevice::CUDA) {
@@ -1294,6 +1476,16 @@ std::map<int, std::vector <CudaMemoryBuffer>> cudaBuffersMap;
 std::map<int, size_t> noBusyCnt;
 std::map<int, std::vector <CudaMemoryBuffer>> bigBuffersMap;

+void * FastllmCudaDirectMalloc(size_t size) {
+    void * ret;
+    cudaMalloc(&ret, size);
+    return ret;
+}
+
+void FastllmCudaDirectFree(void *ret) {
+    cudaFree(ret);
+}
+
 void * FastllmCudaMalloc(size_t size) {
     int id = -1;
     cudaGetDevice(&id);
@@ -1302,7 +1494,7 @@ void * FastllmCudaMalloc(size_t size) {
         int selId = -1;
         for (int i = 0; i < bigBuffers.size(); i++) {
             if (bigBuffers[i].size >= size && !bigBuffers[i].busy
-                && bigBuffers[i].size - size < 32 * 1024 * 1024) {
+                && bigBuffers[i].size - size < 1 * 1024 * 1024) {
                 if (selId == -1 || bigBuffers[selId].size > bigBuffers[i].size) {
                     selId = i;
                 }
@@ -1841,6 +2033,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         }
         FastllmCudaFree(qk);
+        DeviceSync();
         return true;
     }
@@ -1896,6 +2089,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         }
         FastllmCudaFree(qk);
         FastllmCudaFree(temp);
+        DeviceSync();
         return true;
     }
     return true;
@@ -2044,6 +2238,157 @@ bool FastllmCudaApplyLognAttn (fastllm::Data &input, fastllm::Data &lognAttn, fa
     return true;
 }

+bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v, fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch) {
+    int k0 = k[0]->dims[0];
+    size_t memSum = 0;
+    for (int b = 0; b < batch; b++) {
+        memSum += q[b]->dims[0] * q[b]->dims[1] * k[b]->dims[1];
+    }
+    float *mem = (float *) FastllmCudaMalloc(memSum * sizeof(float));
+    float **qk = new float*[batch];
+    memSum = 0;
+    for (int b = 0; b < batch; b++) {
+        int s = q[b]->dims[0] * q[b]->dims[1] * k[b]->dims[1];
+        qk[b] = mem + memSum;
+        memSum += s;
+    }
+
+    if (true) {
+        uint8_t ** pointers = (uint8_t**) FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
+        uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
+        for (int b = 0; b < batch; b++) {
+            for (int i = 0; i < k0; i++) {
+                cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) q[b]->cudaData + i * group * q[b]->dims[1] * q[b]->dims[2] * sizeof(float);
+                cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) k[b]->cudaData + i * k[b]->strides[0] * sizeof(float);
+                cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(float);
+                cpuPointers[(b * k0 + i) * 8 + 3] = (uint8_t *) (size_t) (group * q[b]->dims[1]);
+                cpuPointers[(b * k0 + i) * 8 + 4] = (uint8_t *) (size_t) q[b]->dims[2];
+                cpuPointers[(b * k0 + i) * 8 + 5] = (uint8_t *) (size_t) k[b]->dims[1];
+                cpuPointers[(b * k0 + i) * 8 + 6] = (uint8_t *) (size_t) q[b]->strides[1];
+                cpuPointers[(b * k0 + i) * 8 + 7] = (uint8_t *) (size_t) k[b]->strides[1];
+            }
+        }
+        cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * k0 * 8, cudaMemcpyHostToDevice);
+        FastllmMatMulTransBBatchKernel <128> <<<batch * k0, 128>>> (pointers, scale);
+        FastllmCudaFree(pointers);
+        delete[] cpuPointers;
+    }
+
+    if (true) {
+        int total = 0;
+        for (int b = 0; b < batch; b++) {
+            int outer = q[b]->dims[0] * q[b]->dims[1];
+            total += outer;
+        }
+        uint8_t ** pointers = (uint8_t**) FastllmCudaMalloc(sizeof(uint8_t*) * total * 3);
+        uint8_t ** cpuPointers = new uint8_t*[total * 3];
+        int cur = 0;
+        for (int b = 0; b < batch; b++) {
+            int outer = q[b]->dims[0] * q[b]->dims[1];
+            int channels = k[b]->dims[1];
+            for (int o = 0; o < outer; o++) {
+                cpuPointers[cur * 3 + 0] = (uint8_t*)(qk[b] + o * channels);
+                cpuPointers[cur * 3 + 1] = (uint8_t*)(qk[b] + o * channels);
+                cpuPointers[cur * 3 + 2] = (uint8_t*)((size_t) channels);
+                cur++;
+            }
+        }
+        cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice);
+        FastllmSoftmaxKernelBatchInner1 <256> <<<total, 256>>> (pointers);
+        FastllmCudaFree(pointers);
+        delete[] cpuPointers;
+    }
+
+    if (true) {
+        uint8_t ** pointers = (uint8_t**) FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
+        uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
+        for (int b = 0; b < batch; b++) {
+            for (int i = 0; i < k0; i++) {
+                cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(float);
+                cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) v[b]->cudaData + i * v[b]->strides[0] * sizeof(float);
+                cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) output[b]->cudaData + i * group * q[b]->dims[1] * v[b]->dims[2] * sizeof(float);
+                cpuPointers[(b * k0 + i) * 8 + 3] = (uint8_t *) (size_t) (group * q[b]->dims[1]);
+                cpuPointers[(b * k0 + i) * 8 + 4] = (uint8_t *) (size_t) k[b]->dims[1];
+                cpuPointers[(b * k0 + i) * 8 + 5] = (uint8_t *) (size_t) v[b]->dims[2];
+                cpuPointers[(b * k0 + i) * 8 + 6] = (uint8_t *) (size_t) k[b]->dims[1];
+                cpuPointers[(b * k0 + i) * 8 + 7] = (uint8_t *) (size_t) v[b]->strides[1];
+            }
+        }
+        cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * k0 * 8, cudaMemcpyHostToDevice);
+        FastllmMatMulKernel <128> <<<batch * k0, 128>>> (pointers, 1.0f);
+        FastllmCudaFree(pointers);
+        delete[] cpuPointers;
+    }
+
+    FastllmCudaFree(mem);
+    delete[] qk;
/*
{
const int params = 16;
float **pointers = (float **) FastllmCudaMalloc(sizeof(float *) * batch * params);
float **cpuPointers = new float *[batch * params];
float **qk = new float *[batch];
float **temp = new float *[batch];
for (int b = 0; b < batch; b++) {
qk[b] = (float *) FastllmCudaMalloc(q[b]->dims[0] * k[b]->dims[1] * sizeof(float));
temp[b] = (float *) FastllmCudaMalloc(q[b]->dims[0] * k[b]->dims[1] * sizeof(float));
cpuPointers[b * params + 0] = (float *) q[b]->cudaData;
cpuPointers[b * params + 1] = (float *) k[b]->cudaData;
cpuPointers[b * params + 2] = (float *) v[b]->cudaData;
cpuPointers[b * params + 3] = (mask[b] && mask[b]->dims.size() > 0) ? (float *) mask[b]->cudaData : nullptr;
cpuPointers[b * params + 4] = (float *) output[b]->cudaData;
cpuPointers[b * params + 5] = (float *) (unsigned long long) q[b]->dims[1];
cpuPointers[b * params + 6] = (float *) (unsigned long long) q[b]->dims[2];
cpuPointers[b * params + 7] = (float *) (unsigned long long) k[b]->dims[1];
cpuPointers[b * params + 8] = (float *) (unsigned long long) v[b]->dims[2];
cpuPointers[b * params + 9] = (float *) (unsigned long long) q[b]->strides[0];
cpuPointers[b * params + 10] = (float *) (unsigned long long) k[b]->strides[0];
cpuPointers[b * params + 11] = (float *) (unsigned long long) v[b]->strides[0];
cpuPointers[b * params + 12] = (float *) (unsigned long long) output[b]->strides[0];
cpuPointers[b * params + 13] = (float *) (unsigned long long) qk[b];
cpuPointers[b * params + 14] = (float *) (unsigned long long) temp[b];
cpuPointers[b * params + 15] = (float *) (unsigned long long) q[b]->dims[0];
}
cudaMemcpy(pointers, cpuPointers, sizeof(float *) * batch * params, cudaMemcpyHostToDevice);
FastllmAttentionBatchKernel<256> <<< batch, 256 >>>(pointers, scale, group);
for (int i = 0; i < batch; i++) {
FastllmCudaFree(qk[i]);
FastllmCudaFree(temp[i]);
}
delete[] qk;
delete[] temp;
FastllmCudaFree(pointers);
delete[] cpuPointers;
}
*/
/*
for (int b = 0; b < batch; b++) {
int q0 = q[b]->dims[0], q1 = q[b]->dims[1], q2 = q[b]->dims[2], k0 = k[b]->dims[0], k1 = k[b]->dims[1], v2 = v[b]->dims[2];
float *qd = (float *) q[b]->cudaData;
float *kd = (float *) k[b]->cudaData;
float *vd = (float *) v[b]->cudaData;
float *maskd = (mask[b] && mask[b]->dims.size() > 0) ? (float *) mask[b]->cudaData : nullptr;
float *od = (float *) output[b]->cudaData;
int maskBatch = (mask[b] && mask[b]->dims.size() > 0) ? mask[b]->dims[0] : 1;
float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
FastllmAttentionKernel<256> <<<q0, 256>>>(qd, kd, vd, maskd, od,
scale, q1, q2, k1, v2,
group, q[b]->strides[0], k[b]->strides[0], v[b]->strides[0],
output[b]->strides[0],
qk, temp);
}
*/
    DeviceSync();
    return true;
}

bool FastllmCudaSplitBatch(fastllm::Data &input, fastllm::Data **outputs, int axis) {
    int part = input.dims[axis];
    int outer = input.Count(0) / input.Count(axis);
...
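FastllmCudaSplitBatch is truncated here, but the indexing it sets up — part = dims[axis], outer = Count(0) / Count(axis) — is the usual outer/part/inner decomposition for slicing a contiguous tensor along one axis. A CPU-only sketch of that decomposition (hypothetical splitAxis helper, not the fastllm API):

#include <cstdio>
#include <vector>

// Split a row-major buffer along `axis`: outer = product of dims before axis,
// part = dims[axis], inner = product of dims after axis.
static std::vector<std::vector<float>> splitAxis(const std::vector<float> &input,
                                                 int outer, int part, int inner) {
    std::vector<std::vector<float>> outputs(part, std::vector<float>(outer * inner));
    for (int o = 0; o < outer; o++) {
        for (int p = 0; p < part; p++) {
            for (int i = 0; i < inner; i++) {
                outputs[p][o * inner + i] = input[(o * part + p) * inner + i];
            }
        }
    }
    return outputs;
}

int main() {
    // shape {2, 3, 2} split on axis 1 -> 3 outputs of shape {2, 1, 2}
    std::vector<float> x(12);
    for (int i = 0; i < 12; i++) x[i] = (float) i;
    auto parts = splitAxis(x, /*outer=*/2, /*part=*/3, /*inner=*/2);
    printf("%zu outputs, first = %.0f %.0f %.0f %.0f\n",
           parts.size(), parts[0][0], parts[0][1], parts[0][2], parts[0][3]);
    return 0;
}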
src/executor.cpp
View file @ 44be91d3
...
@@ -69,10 +69,10 @@ namespace fastllm {
             if (intParams.find(it.first + "___batch") != intParams.end()) {
                 int batch = intParams.find(it.first + "___batch")->second;
                 for (int i = 0; i < batch; i++) {
-                    lockInCPU |= ((Data **) it.second)[i]->lockInCPU;
+                    lockInCPU |= (((Data **) it.second)[i] && ((Data **) it.second)[i]->lockInCPU);
                 }
             } else {
-                lockInCPU |= it.second->lockInCPU;
+                lockInCPU |= (it.second && it.second->lockInCPU);
             }
         }
         for (auto device : devices) {
...
@@ -89,10 +89,14 @@ namespace fastllm {
             if (intParams.find(it.first + "___batch") != intParams.end()) {
                 int batch = intParams.find(it.first + "___batch")->second;
                 for (int i = 0; i < batch; i++) {
-                    ((Data **) it.second)[i]->ToDevice((void *) device);
+                    if (((Data **) it.second)[i]) {
+                        ((Data **) it.second)[i]->ToDevice((void *) device);
+                    }
                 }
             } else {
-                it.second->ToDevice((void *) device);
+                if (it.second) {
+                    it.second->ToDevice((void *) device);
+                }
             }
         }
         device->Reshape(opType, datas, floatParams, intParams);
...
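Both hunks above harden the batched-parameter path: a Data** registered under a "___batch" key may now contain null slots (for example, requests that carry no attention mask), so the executor checks each pointer before reading lockInCPU or calling ToDevice. A minimal sketch of the null-safe walk, with a toy Data stand-in rather than fastllm's class:

#include <cstdio>

// Toy stand-in for fastllm::Data, just enough to show the guarded batch walk.
struct Data {
    bool lockInCPU = false;
    void ToDevice(int device) { (void) device; /* moves storage in the real code */ }
};

// Batched parameters arrive as a Data** plus an "<name>___batch" count; some
// slots (per-request masks, for instance) may legitimately be nullptr.
static bool anyLockedInCPU(Data **datas, int batch) {
    bool lockInCPU = false;
    for (int i = 0; i < batch; i++) {
        lockInCPU |= (datas[i] && datas[i]->lockInCPU);   // null-safe, as in the patch
    }
    return lockInCPU;
}

int main() {
    Data a, b;
    b.lockInCPU = true;
    Data *batch[3] = {&a, nullptr, &b};   // middle slot: request without a mask
    printf("locked = %d\n", anyLockedInCPU(batch, 3));   // 1, and no null dereference
    return 0;
}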
src/fastllm.cpp
View file @ 44be91d3
...
@@ -368,7 +368,11 @@ namespace fastllm {
             this->cpuData = new uint8_t[this->expansionBytes];
         } else if (this->dataDevice == DataDevice::CUDA) {
 #ifdef USE_CUDA
-            this->cudaData = FastllmCudaMalloc(this->expansionBytes);
+            if (this->directMemory) {
+                this->cudaData = FastllmCudaDirectMalloc(this->expansionBytes);
+            } else {
+                this->cudaData = FastllmCudaMalloc(this->expansionBytes);
+            }
 #else
             ErrorInFastLLM("Error: cuda is not supported.\n");
 #endif
...
@@ -382,7 +386,11 @@ namespace fastllm {
             delete[] this->cpuData;
         } else if (this->dataDevice == DataDevice::CUDA) {
 #ifdef USE_CUDA
-            FastllmCudaFree(this->cudaData);
+            if (this->directMemory) {
+                FastllmCudaDirectFree(this->cudaData);
+            } else {
+                FastllmCudaFree(this->cudaData);
+            }
 #else
             ErrorInFastLLM("Error: cuda is not supported.\n");
 #endif
...
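These two hunks give Data a second allocation path: buffers flagged directMemory go through FastllmCudaDirectMalloc/FastllmCudaDirectFree, while everything else keeps using the pooled FastllmCudaMalloc/FastllmCudaFree — presumably so that large, long-lived KV-cache blocks bypass the caching allocator. A sketch of the switch, with malloc/free standing in for the CUDA calls (names and the pool itself are placeholders):

#include <cstdio>
#include <cstdlib>

// malloc/free stand in for FastllmCudaDirectMalloc and the pooled FastllmCudaMalloc.
static void *directMalloc(size_t bytes) { return std::malloc(bytes); }   // cudaMalloc in reality
static void *pooledMalloc(size_t bytes) { return std::malloc(bytes); }   // pool allocator in reality
static void directFree(void *p) { std::free(p); }
static void pooledFree(void *p) { std::free(p); }

struct Buffer {
    bool directMemory = false;   // set when the buffer is sized in one shot (see Expansion below)
    void *ptr = nullptr;

    void allocate(size_t bytes) {
        ptr = directMemory ? directMalloc(bytes) : pooledMalloc(bytes);
    }
    void release() {
        if (!ptr) return;
        directMemory ? directFree(ptr) : pooledFree(ptr);
        ptr = nullptr;
    }
};

int main() {
    Buffer kvCache;
    kvCache.directMemory = true;   // large and long-lived: skip the pool
    kvCache.allocate(1 << 20);
    printf("allocated = %p\n", kvCache.ptr);
    kvCache.release();
    return 0;
}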
@@ -415,6 +423,7 @@ namespace fastllm {
     void Data::Expansion(const std::vector<int> &dims) {
         if (this->dims.size() == 0) {
+            this->directMemory = true;
             this->strides.resize(dims.size(), 1);
             this->strides.back() = 1;
             for (int i = dims.size() - 2; i >= 0; i--) {
...
@@ -489,6 +498,11 @@ namespace fastllm {
 #ifdef USE_CUDA
         if (this->cudaData != nullptr) {
             FastllmCudaFree(this->cudaData);
+            /*if (this->directMemory) {
+                FastllmCudaDirectFree(this->cudaData);
+            } else {
+                FastllmCudaFree(this->cudaData);
+            }*/
         }
 #endif
     }
...
@@ -524,6 +538,10 @@ namespace fastllm {
             }
             printf("\n");
         */
+        // // To print data that lives in CUDA memory, move it to the CPU first. xzhou 20230728
+        // if (dataDevice == DataDevice::CUDA) {
+        //     ToDevice(DataDevice::CPU);
+        // }
         int n = Count(0) / dims.back(), m = dims.back();
         for (int i = 0; i < n; i++) {
             for (int j = 0; j < 10 && j < m; j++) {
...
@@ -548,7 +566,7 @@ namespace fastllm {
         weightSum.resize(n);
         for (int i = 0; i < n; i++) {
             int j = 0;
-#ifdef __AVX__
+#ifdef __AVX2__
             __m256i acc = _mm256_setzero_si256();
             const __m256i ones = _mm256_set1_epi16(1);
             for (; j + 31 < m; j += 32) {
...
@@ -594,7 +612,7 @@ namespace fastllm {
             }
             weightSum[i] += sum0[0] + sum0[1] + sum0[2] + sum0[3];
 #endif
-#ifdef __AVX__
+#ifdef __AVX2__
             __m256i acc = _mm256_setzero_si256();
             const __m256i lowMask = _mm256_set1_epi8(0xf);
             const __m256i ones = _mm256_set1_epi16(1);
...
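The guard changes from __AVX__ to __AVX2__ because the 256-bit integer intrinsics used in these loops (_mm256_madd_epi16 and friends on __m256i) are AVX2 instructions; a build with only AVX defined would not provide them. A small stand-alone reduction in the same style — a sketch, not the fastllm routine itself — behind the corrected guard:

#include <cstdint>
#include <cstdio>
#ifdef __AVX2__
#include <immintrin.h>
#endif

// Sum m unsigned bytes; the vector path mirrors the weightSum loop's shape and
// only exists when AVX2 (256-bit integer ops) is available.
static int sumBytes(const uint8_t *a, int m) {
    int sum = 0, j = 0;
#ifdef __AVX2__
    __m256i acc = _mm256_setzero_si256();
    const __m256i ones = _mm256_set1_epi16(1);
    for (; j + 31 < m; j += 32) {
        __m256i bytes = _mm256_loadu_si256((const __m256i *) (a + j));
        // widen u8 -> u16 in two halves, then fold u16 pairs into u32 via madd with 1
        __m256i lo = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(bytes, 0));
        __m256i hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(bytes, 1));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(lo, ones));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(hi, ones));
    }
    int32_t lanes[8];
    _mm256_storeu_si256((__m256i *) lanes, acc);
    for (int t = 0; t < 8; t++) sum += lanes[t];
#endif
    for (; j < m; j++) sum += a[j];   // scalar tail, and the whole non-AVX2 path
    return sum;
}

int main() {
    uint8_t buf[100];
    for (int i = 0; i < 100; i++) buf[i] = 2;
    printf("%d\n", sumBytes(buf, 100));   // 200
    return 0;
}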
@@ -795,6 +813,18 @@
             q.push(SymbolPairs(now->score, l, r, symbols[l].len + symbols[r].len));
     }

+    int Tokenizer::GetRank(std::vector<Symbol> &symbols, std::vector<std::pair<int, int>> &partitions, int idx, int skip) {
+        if (idx + skip + 2 >= partitions.size()) {
+            return std::numeric_limits<int>::max();
+        }
+        auto s = symbols[0].s + symbols[0].pos;
+        std::string key(s + partitions[idx].first, s + partitions[idx + skip + 2].first);
+        if (stringToTokenDict.find(key) != stringToTokenDict.end()) {
+            return stringToTokenDict[key];
+        }
+        return std::numeric_limits<int>::max();
+    }
+
     Data Tokenizer::Encode(const std::string &ori) {
         if (this->type == TokenizerType::BPE) {
             std::string blank = "";
...
@@ -926,48 +956,38 @@
                 if (i == sep.back().first) {
                     if (!symbols.empty()) {
                         symbols.back().next = -1;
-                        std::priority_queue<SymbolPairs> workQueue;
-                        for (int i = 1; i < symbols.size(); i++) {
-                            TryMergePairs(symbols, i - 1, i, workQueue);
-                        }
-                        while (!workQueue.empty()) {
-                            auto top = workQueue.top();
-                            workQueue.pop();
-                            if (symbols[top.l].len == 0 || symbols[top.r].len == 0 ||
-                                symbols[top.l].len + symbols[top.r].len != top.size) {
-                                continue;
-                            }
-                            for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) {
-                                symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]];
-                            }
-                            symbols[top.l].len += symbols[top.r].len;
-                            symbols[top.r].len = 0;
-                            symbols[top.l].next = symbols[top.r].next;
-                            if (symbols[top.r].next >= 0) {
-                                symbols[symbols[top.r].next].prev = top.l;
-                            }
-                            TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue);
-                            TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
-                        }
-                        for (int i = 0; i < symbols.size(); i++) {
-                            if (symbols[i].len > 0) {
-                                v.push_back(symbols[i].node->tokenId);
-                            } else if (symbols[i].node == nullptr) {
-                                // unrecognized character
-                                uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]);
-                                std::string now = "<0x00>";
-                                now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16));
-                                now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16));
-                                if (stringToTokenDict.find(now) != stringToTokenDict.end()) {
-                                    v.push_back(stringToTokenDict[now]);
-                                }
-                            }
-                        }
+                        std::string cur = ori.substr(i - symbols.size(), symbols.size());
+                        std::vector<std::pair<int, int>> partitions(symbols.size() + 1);
+                        for (int j = 0; j <= (int) symbols.size(); j++) {
+                            partitions[j] = std::make_pair(j, std::numeric_limits<int>::max());
+                        }
+                        for (int j = 0; j < partitions.size() - 2; j++) {
+                            partitions[j].second = GetRank(symbols, partitions, j, 0);
+                        }
+                        while (partitions.size() > 1) {
+                            int min_rank = std::numeric_limits<int>::max();
+                            int min_rank_idx = 0;
+                            for (int j = 0; j < partitions.size() - 1; ++j) {
+                                if (partitions[j].second < min_rank) {
+                                    min_rank = partitions[j].second;
+                                    min_rank_idx = j;
+                                }
+                            }
+                            if (min_rank != std::numeric_limits<int>::max()) {
+                                partitions[min_rank_idx].second = GetRank(symbols, partitions, min_rank_idx, 1);
+                                if (min_rank_idx > 0) {
+                                    partitions[min_rank_idx - 1].second = GetRank(symbols, partitions, min_rank_idx - 1, 1);
+                                }
+                                partitions.erase(partitions.begin() + min_rank_idx + 1);
+                            } else {
+                                break;
+                            }
+                        }
                         symbols.clear();
+                        for (int j = 0; j < partitions.size() - 1; j++) {
+                            std::string key = cur.substr(partitions[j].first, partitions[j + 1].first - partitions[j].first);
+                            v.push_back((float) stringToTokenDict[key]);
+                        }
                     }
                 }
                 std::string special = ori.substr(sep.back().first, sep.back().second);
...
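GetRank returns the token id of the substring spanning partitions[idx] through partitions[idx + skip + 2] when that string is in the vocabulary, and "infinity" otherwise; Encode then repeatedly merges the boundary with the lowest rank until no adjacent pair is in the vocabulary. A self-contained toy version of that rank-based merge (toy vocabulary in a std::map, not fastllm's Tokenizer):

#include <cstdio>
#include <limits>
#include <map>
#include <string>
#include <vector>

// rank(key) = token id if `key` is in the vocab, otherwise "infinite".
static int rank(const std::map<std::string, int> &vocab, const std::string &s,
                const std::vector<int> &parts, int idx, int skip) {
    if (idx + skip + 2 >= (int) parts.size()) return std::numeric_limits<int>::max();
    std::string key = s.substr(parts[idx], parts[idx + skip + 2] - parts[idx]);
    auto it = vocab.find(key);
    return it == vocab.end() ? std::numeric_limits<int>::max() : it->second;
}

static std::vector<int> encode(const std::map<std::string, int> &vocab, const std::string &s) {
    std::vector<int> parts(s.size() + 1);
    for (int i = 0; i <= (int) s.size(); i++) parts[i] = i;
    while (parts.size() > 2) {
        int best = std::numeric_limits<int>::max(), bestIdx = -1;
        for (int j = 0; j + 2 < (int) parts.size(); j++) {
            int r = rank(vocab, s, parts, j, 0);
            if (r < best) { best = r; bestIdx = j; }
        }
        if (bestIdx < 0) break;                      // no adjacent pair is in the vocab
        parts.erase(parts.begin() + bestIdx + 1);    // merge the lowest-rank boundary
    }
    std::vector<int> ids;
    for (int j = 0; j + 1 < (int) parts.size(); j++) {
        ids.push_back(vocab.at(s.substr(parts[j], parts[j + 1] - parts[j])));
    }
    return ids;
}

int main() {
    std::map<std::string, int> vocab = {{"a", 10}, {"b", 11}, {"c", 12}, {"ab", 5}, {"abc", 1}};
    for (int id : encode(vocab, "abc")) printf("%d ", id);   // prints "1"
    printf("\n");
    return 0;
}

The toy re-scans every boundary each round; the real Encode avoids that by caching each boundary's rank in partitions[j].second and refreshing only the neighbours of the merged pair.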
@@ -1592,6 +1612,14 @@
         }
     }

+    void CopyKVCache(Data &oldCache, Data &newCache, int oldBsStart, int newBsStart, int bs, int offset) {
+        curExecutor->Run("CopyKVCache", {
+                {"oldCache", (Data*)&oldCache}, {"newCache", (Data*)&newCache}
+        }, {}, {{"oldBsStart", oldBsStart}, {"newBsStart", newBsStart}, {"bs", bs}, {"offset", offset}});
+    }
+
     void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
                    int group, float scale, int attentionType) {
         curExecutor->Run("Attention", {
...
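CopyKVCache, like the other wrappers here, just forwards to curExecutor->Run(opName, dataDict, floatDict, intDict). Batched arguments use the convention the executor change above iterates over: a vector of Data* is passed as one Data* (the array's base pointer) plus an "<name>___batch" int. A toy dispatcher showing only that unpacking convention (hypothetical run function, toy Data, not the fastllm executor):

#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Data { int id = 0; };   // toy stand-in for fastllm::Data

// If "x___batch" is present in intParams, the Data* registered under "x" is
// really a Data** of that size; null slots are skipped.
static void run(const std::string &op,
                const std::map<std::string, Data *> &datas,
                const std::map<std::string, int> &intParams) {
    for (const auto &it : datas) {
        auto b = intParams.find(it.first + "___batch");
        if (b != intParams.end()) {
            Data **arr = (Data **) it.second;
            for (int i = 0; i < b->second; i++) {
                if (arr[i]) printf("%s: %s[%d] -> id %d\n", op.c_str(), it.first.c_str(), i, arr[i]->id);
            }
        } else if (it.second) {
            printf("%s: %s -> id %d\n", op.c_str(), it.first.c_str(), it.second->id);
        }
    }
}

int main() {
    Data a{1}, b{2};
    std::vector<Data *> qs = {&a, nullptr, &b};   // batched input with a null slot
    run("AttentionBatch", {{"q", (Data *) qs.data()}}, {{"q___batch", (int) qs.size()}});
    return 0;
}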
@@ -1814,6 +1842,21 @@
         }, {}, {{"axis", axis}, {"input0___batch", (int)input0.size()}, {"input1___batch", (int)input1.size()}});
     }

+    void AttentionBatch(std::vector<Data*> &q, std::vector<Data*> &k, std::vector<Data*> &v, std::vector<Data*> &mask,
+                        std::vector<Data*> &output, int group, float scale, int attentionType) {
+        curExecutor->Run("AttentionBatch", {
+                {"q", (Data*)q.data()}, {"k", (Data*)k.data()}, {"v", (Data*)v.data()},
+                {"mask", (Data*)mask.data()}, {"output", (Data*)output.data()}
+        }, {{"scale", scale}}, {
+                {"group", group},
+                {"q___batch", (int)q.size()}, {"k___batch", (int)k.size()}, {"v___batch", (int)v.size()},
+                {"mask___batch", (int)mask.size()}, {"output___batch", (int)output.size()}
+        });
+    }
+
     void LoraLayer(Data &input, Data &weight, Data &loraA, Data &loraB, const Data &bias, Data &output,
                    std::map<std::string, std::string> loraConfig) {
         float r = std::atof(loraConfig["r"].c_str());
...
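AttentionBatch hands every per-request q/k/v/mask/output tensor to the device in one call; group is the number of query heads sharing each k/v head (hence the qs[0]->dims[0] / values[0]->dims[0] argument at the chatglm call site below) and scale is the usual 1/sqrt(head_dim) factor. As a reference for what one (batch, head) job computes — softmax(q·kᵀ·scale + mask)·v — here is a plain scalar sketch, not the fastllm kernel:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// q is [n, d], k is [m, d], v is [m, dv]; mask (optional) is [n, m], added before softmax.
static void attention(const std::vector<float> &q, const std::vector<float> &k,
                      const std::vector<float> &v, const float *mask,
                      std::vector<float> &out, int n, int m, int d, int dv, float scale) {
    std::vector<float> score(m);
    out.assign(n * dv, 0.0f);
    for (int i = 0; i < n; i++) {
        float maxv = -1e30f, sum = 0.0f;
        for (int j = 0; j < m; j++) {
            float s = 0.0f;
            for (int t = 0; t < d; t++) s += q[i * d + t] * k[j * d + t];
            s = s * scale + (mask ? mask[i * m + j] : 0.0f);
            score[j] = s;
            maxv = std::max(maxv, s);
        }
        for (int j = 0; j < m; j++) { score[j] = std::exp(score[j] - maxv); sum += score[j]; }
        for (int j = 0; j < m; j++) {
            float w = score[j] / sum;
            for (int t = 0; t < dv; t++) out[i * dv + t] += w * v[j * dv + t];
        }
    }
}

int main() {
    std::vector<float> q = {1, 0}, k = {1, 0, 0, 1}, v = {2, 0, 0, 2}, out;
    attention(q, k, v, nullptr, out, 1, 2, 2, 2, 1.0f / std::sqrt(2.0f));
    printf("%f %f\n", out[0], out[1]);
    return 0;
}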
src/models/basellm.cpp
View file @ 44be91d3
...
@@ -5,6 +5,7 @@
 #include "basellm.h"
 #include "utils.h"
 #include <sstream>
+#include <cstring>
 #ifdef USE_CUDA
 #include "fastllm-cuda.cuh"
...
@@ -339,10 +340,24 @@
             LastTokensManager tokensManager;
             std::vector<std::vector<float>*> logits;
             model->dictLocker.lock();
+            int limit = model->tokensLimit > 0 ? model->tokensLimit : 1e9;
+            int lenSum = 0;
+            for (auto &it : model->responseContextDict.dicts) {
+                if (it.second->pastKeyValues[0].first.expansionDims.size() > 0 && !it.second->isEnding) {
+                    lenSum += it.second->pastKeyValues[0].first.expansionDims[1];
+                }
+            }
             for (int isPrompt = 1; isPrompt >= 0; isPrompt--) {
+                int cnt = 0;
                 if (isPrompt == 0 && seqLens.size() > 0) {
                     continue;
                 }
+                if (lenSum > limit && isPrompt) {
+                    continue;
+                }
                 for (auto &it : model->responseContextDict.dicts) {
                     if (it.second->isEnding) {
                         continue;
...
@@ -350,6 +365,16 @@
                     if (isPrompt && it.second->preTokens != 0) {
                         continue;
                     }
+                    if (!isPrompt && it.second->preTokens == 0) {
+                        continue;
+                    }
+                    int outputLimit = it.second->generationConfig.output_token_limit;
+                    outputLimit = (outputLimit < 0 ? 128 : outputLimit);
+                    if (isPrompt && lenSum + it.second->currentTokens.size() + outputLimit > limit) {
+                        continue;
+                    }
                     generationConfigs.push_back(it.second->generationConfig);
                     if (it.second->generationConfig.output_logits) {
                         it.second->resultLogits.push(new std::vector<float>());
...
@@ -397,6 +422,7 @@
                                                    &it.second->pastKeyValues[i].second));
                     }
                     if (isPrompt) {
+                        cnt += it.second->currentTokens.size();
                         break;
                     }
                 }
...
@@ -412,6 +438,8 @@
 #endif
             Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids);
             std::vector<int> ret;
+            auto st = std::chrono::system_clock::now();
+            //ClearProfiler();
             if (seqLens.size() > 1) {
                 ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks,
                                           positionIds, seqLens, pastKeyValues, generationConfigs,
...
@@ -422,7 +450,13 @@
                                                          *positionIds[0],
                                                          *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])};
             }
+            //PrintProfiler();
+            /*
+            static int tot = 0;
+            printf("len = %d, spend = %f s.\n", (int)seqLens.size(), GetSpan(st, std::chrono::system_clock::now()));
+            tot += (int)seqLens.size();
+            printf("tot = %d\n", tot);
+            */
             model->dictLocker.lock();
             for (int i = 0; i < handles.size(); i++) {
                 auto &it = *model->responseContextDict.dicts.find(handles[i]);
...
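The new scheduling keeps a running total lenSum of the KV positions already reserved (each live request's expansionDims[1]) and only admits a new prompt when that total, plus the prompt length, plus the request's output_token_limit (128 if unset) fits under tokensLimit. A stand-alone sketch of that admission test (hypothetical Request/admit names, not the basellm types):

#include <cstdio>
#include <vector>

struct Request {
    int promptTokens = 0;
    int outputTokenLimit = -1;   // < 0 means "not set"; the patch falls back to 128
    int reservedKV = 0;          // expansionDims[1] once the request is running
    bool ending = false;
};

// Decide whether a new prompt fits under the global KV-token budget.
static bool admit(const std::vector<Request> &running, const Request &candidate, int tokensLimit) {
    long long limit = tokensLimit > 0 ? tokensLimit : 1000000000LL;
    long long lenSum = 0;
    for (const Request &r : running) {
        if (!r.ending) lenSum += r.reservedKV;
    }
    int outputLimit = candidate.outputTokenLimit < 0 ? 128 : candidate.outputTokenLimit;
    return lenSum + candidate.promptTokens + outputLimit <= limit;
}

int main() {
    std::vector<Request> running = {{0, -1, 1500, false}, {0, -1, 2000, false}};
    Request incoming{300, 256, 0, false};
    printf("admit = %d\n", admit(running, incoming, 4096));   // 1: 3500 + 300 + 256 <= 4096
    return 0;
}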
src/models/chatglm.cpp
View file @ 44be91d3
...
@@ -190,7 +190,7 @@
             if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) {
                 newDims = std::vector<int>{k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]};
                 if (generationConfig.output_token_limit > 0) {
-                    newDims[1] = std::min(newDims[1], k.dims[1] + generationConfig.output_token_limit);
+                    newDims[1] = k.dims[1] + generationConfig.output_token_limit;
                 }
             } else {
                 newDims = pastKey.dims;
...
@@ -207,7 +207,7 @@
             if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) {
                 newDims = std::vector<int>{v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]};
                 if (generationConfig.output_token_limit > 0) {
-                    newDims[1] = std::min(newDims[1], k.dims[1] + generationConfig.output_token_limit);
+                    newDims[1] = k.dims[1] + generationConfig.output_token_limit;
                 }
             } else {
                 newDims = pastValue.dims;
...
@@ -377,12 +377,12 @@
                     CatDirect(*(Data*)positionIds[0], *(Data*)positionIds[i], 1);
                 }
             }
-            std::vector<Data*> keys, values, qs, attns, contexts;
+            std::vector<Data*> keys, values, qs, attns, masks, contexts;
             keys.resize(batch);
             values.resize(batch);
             qs.resize(batch);
             attns.resize(batch);
+            masks.resize(batch);
             contexts.resize(batch);
             std::vector<Data*> pointersK, pointersV, pointersQ;
...
@@ -486,6 +486,10 @@
                     auto &q = curQs[b], &k = curKs[b], &v = curVs[b];
                     Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                    if (pastKey.dims.size() > 0 && pastKey.dims[1] + k.dims[1] <= pastKey.expansionDims[1]) {
+                        continue;
+                    }
                     pastKey.ToDevice(DataDevice::CUDA);
                     pastValue.ToDevice(DataDevice::CUDA);
...
@@ -533,64 +537,76 @@
                 }
                 CatDirectBatch(keys, pointersK, 1);
                 CatDirectBatch(values, pointersV, 1);
-                for (int b = 0; b < batch; b++) {
-                    auto &q = curQs[b];
-                    Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
-                    outputSizes[b] = {1, q.dims[0], q.dims[1], pastKey.dims[1]};
-                    q.Reshape({pastKey.dims[0], -1, q.dims[2]});
-                }
-                // 1.2 Attention
-                // 1.2.0 q * k^T
+                // 1.2 Attention
                 if (all1 && batch > 1) {
                     for (int b = 0; b < batch; b++) {
                         qs[b] = (&curQs[b]);
                         keys[b] = (pastKeyValues[b * block_cnt + i].first);
-                        attns[b] = (&attnProbs[b]);
+                        values[b] = (pastKeyValues[b * block_cnt + i].second);
+                        masks[b] = attentionMask[b];
+                        contexts[b] = (&curContextLayer[b]);
+                        outputSizes[b] = {1, qs[b]->dims[0], qs[b]->dims[1], keys[b]->dims[1]};
                     }
-                    MatMulTransBBatch(qs, keys, attns, 1.0 / (scale_attn * (i + 1)));
+                    AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], 1.0 / scale_attn, 1);
                 } else {
                     for (int b = 0; b < batch; b++) {
                         auto &q = curQs[b];
                         Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
-                        MatMulTransB(q, pastKey, attnProbs[b], 1.0 / (scale_attn * (i + 1)));
+                        outputSizes[b] = {1, q.dims[0], q.dims[1], pastKey.dims[1]};
+                        q.Reshape({pastKey.dims[0], -1, q.dims[2]});
                     }
-                }
-                for (int b = 0; b < batch; b++) {
-                    attnProbs[b].Reshape(outputSizes[b]);
-                    // 1.2.1 Mask
-                    if (attentionMask[b] != nullptr) {
-                        AttentionMask(attnProbs[b], *attentionMask[b], -10000);
-                    }
-                }
-                // 1.2.2 softmax
-                for (int i = 0; i < attnProbs.size(); i++) {
-                    attns[i] = (&attnProbs[i]);
-                }
-                MulBatch(attns, i + 1, attns);
-                SoftmaxBatch(attns, attns, -1);
-                for (int b = 0; b < batch; b++) {
-                    Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
-                    outputSizes[b] = {1, num_attention_heads, -1, pastValue.dims[2]};
-                    attnProbs[b].Reshape({pastValue.dims[0], -1, attnProbs[b].dims[3]});
-                }
-                // 1.2.3 prob * v
-                if (all1 && batch > 1) {
-                    for (int b = 0; b < batch; b++) {
-                        attns[b] = (&attnProbs[b]);
-                        values[b] = (pastKeyValues[b * block_cnt + i].second);
-                        contexts[b] = (&curContextLayer[b]);
-                    }
-                    MatMulBatch(attns, values, contexts);
-                } else {
-                    for (int b = 0; b < batch; b++) {
-                        Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
-                        MatMul(attnProbs[b], pastValue, curContextLayer[b]);
-                    }
-                }
+                    // 1.2.0 q * k^T
+                    if (all1 && batch > 1) {
+                        for (int b = 0; b < batch; b++) {
+                            qs[b] = (&curQs[b]);
+                            keys[b] = (pastKeyValues[b * block_cnt + i].first);
+                            attns[b] = (&attnProbs[b]);
+                        }
+                        MatMulTransBBatch(qs, keys, attns, 1.0 / (scale_attn * (i + 1)));
+                    } else {
+                        for (int b = 0; b < batch; b++) {
+                            auto &q = curQs[b];
+                            Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
+                            MatMulTransB(q, pastKey, attnProbs[b], 1.0 / (scale_attn * (i + 1)));
+                        }
+                    }
+                    for (int b = 0; b < batch; b++) {
+                        attnProbs[b].Reshape(outputSizes[b]);
+                        // 1.2.1 Mask
+                        if (attentionMask[b] != nullptr) {
+                            AttentionMask(attnProbs[b], *attentionMask[b], -10000);
+                        }
+                    }
+                    // 1.2.2 softmax
+                    for (int i = 0; i < attnProbs.size(); i++) {
+                        attns[i] = (&attnProbs[i]);
+                    }
+                    MulBatch(attns, i + 1, attns);
+                    SoftmaxBatch(attns, attns, -1);
+                    for (int b = 0; b < batch; b++) {
+                        Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                        outputSizes[b] = {1, num_attention_heads, -1, pastValue.dims[2]};
+                        attnProbs[b].Reshape({pastValue.dims[0], -1, attnProbs[b].dims[3]});
+                    }
+                    // 1.2.3 prob * v
+                    if (all1 && batch > 1) {
+                        for (int b = 0; b < batch; b++) {
+                            attns[b] = (&attnProbs[b]);
+                            values[b] = (pastKeyValues[b * block_cnt + i].second);
+                            contexts[b] = (&curContextLayer[b]);
+                        }
+                        MatMulBatch(attns, values, contexts);
+                    } else {
+                        for (int b = 0; b < batch; b++) {
+                            Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                            MatMul(attnProbs[b], pastValue, curContextLayer[b]);
+                        }
+                    }
+                }
                 if (all1) {
...
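The first two hunks in this file switch between two sizing rules for the initial KV-cache reservation along the sequence axis: rounding the current length up to the next unitLen multiple, versus reserving k.dims[1] + output_token_limit outright so the cache never has to be re-expanded (and copied) mid-generation — which is also what lets the lenSum accounting in basellm.cpp treat expansionDims[1] as the request's full budget. A small sketch of both rules (hypothetical reserveSeqLen helper):

#include <cstdio>

// First-time KV-cache reservation along the sequence axis.
// unitLen is the rounding granularity; outputTokenLimit <= 0 means "unknown".
static int reserveSeqLen(int promptLen, int unitLen, int outputTokenLimit) {
    int rounded = ((promptLen - 1) / unitLen + 1) * unitLen;
    if (outputTokenLimit > 0) {
        // Reserve room for the whole generation up front.
        return promptLen + outputTokenLimit;
    }
    return rounded;
}

int main() {
    printf("%d\n", reserveSeqLen(130, 64, 0));     // 192: next multiple of 64
    printf("%d\n", reserveSeqLen(130, 64, 512));   // 642: prompt + output budget
    return 0;
}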