jerrrrry / infinilm · Commits · d7965f91

Commit d7965f91, authored Jul 18, 2025 by wooway777
issue/21 - Initial Modularization
Parent: f59c7bf5
Showing 8 changed files with 599 additions and 245 deletions (+599, −245):

  src/models/cache_manager.hpp      +295  −0
  src/models/inference_context.cpp  +188  −0
  src/models/inference_context.hpp   +39  −0
  src/models/jiuge/jiuge.cpp         +73  −244
  src/tensor.hpp                      +1  −0
  src/tensor/strorage.cpp             +1  −1
  src/tensor/tensor.cpp               +1  −0
  xmake.lua                           +1  −0
src/models/cache_manager.hpp — new file (mode 100644), +295
#ifndef CACHE_MANAGER_HPP
#define CACHE_MANAGER_HPP

#include <cstdint>     // added: uint32_t (missing from the original commit)
#include <cstring>     // added: std::memcpy (missing from the original commit)
#include <functional>
#include <memory>
#include <stdexcept>   // added: std::runtime_error (missing from the original commit)
#include <type_traits> // added: std::enable_if / std::is_enum (missing from the original commit)
#include <unordered_map>
#include <vector>

#include "../tensor.hpp"
#include "../utils.hpp"
#include "infinicore_infer.h"

// Hash combine utility (similar to boost::hash_combine)
inline void hash_combine(size_t &seed, size_t value) {
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Specialization for enum types
template <typename T>
inline void hash_combine(size_t &seed, T value,
                         typename std::enable_if<std::is_enum<T>::value>::type * = 0) {
    hash_combine(seed, static_cast<size_t>(value));
}

// Specialization for float to handle potential precision issues
inline void hash_combine(size_t &seed, float value) {
    // Treat float bits as uint32_t for consistent hashing
    uint32_t int_value;
    static_assert(sizeof(value) == sizeof(int_value), "Size mismatch");
    std::memcpy(&int_value, &value, sizeof(value));
    hash_combine(seed, static_cast<size_t>(int_value));
}

// Helper function to compute hash for tensor descriptors
inline size_t computeTensorDescHash(std::shared_ptr<TensorDesc> desc) {
    size_t seed = 0;
    hash_combine(seed, desc->dtype());
    for (auto dim : desc->shape()) {
        hash_combine(seed, dim);
    }
    for (auto stride : desc->strides()) {
        hash_combine(seed, static_cast<size_t>(stride));
    }
    return seed;
}

enum class OperatorType {
    RMS_NORM,
    GEMM,
    ROPE,
    REARRANGE,
    CAUSAL_SOFTMAX,
    SWIGLU,
    RANDOM_SAMPLE
};

template <typename DescriptorType>
class LRUDescriptorCache {
private:
    struct CacheNode {
        size_t key;
        DescriptorType desc;
        CacheNode *prev;
        CacheNode *next;
        CacheNode() : key(0), desc(), prev(nullptr), next(nullptr) {}
        CacheNode(size_t k, const DescriptorType &d)
            : key(k), desc(d), prev(nullptr), next(nullptr) {}
    };

    std::unordered_map<size_t, CacheNode *> cache;
    CacheNode *head;
    CacheNode *tail;
    const size_t capacity;
    size_t size;
    const OperatorType opType;

    void destroyDescriptor(DescriptorType &desc) {
        switch (opType) {
        case OperatorType::RMS_NORM:
            infiniopDestroyRMSNormDescriptor(desc);
            break;
        case OperatorType::GEMM:
            infiniopDestroyGemmDescriptor(desc);
            break;
        case OperatorType::ROPE:
            infiniopDestroyRoPEDescriptor(desc);
            break;
        case OperatorType::REARRANGE:
            infiniopDestroyRearrangeDescriptor(desc);
            break;
        case OperatorType::CAUSAL_SOFTMAX:
            infiniopDestroyCausalSoftmaxDescriptor(desc);
            break;
        case OperatorType::SWIGLU:
            infiniopDestroySwiGLUDescriptor(desc);
            break;
        case OperatorType::RANDOM_SAMPLE:
            infiniopDestroyRandomSampleDescriptor(desc);
            break;
        default:
            throw std::runtime_error("Unknown descriptor type");
        }
    }

    void removeNode(CacheNode *node) {
        node->prev->next = node->next;
        node->next->prev = node->prev;
        destroyDescriptor(node->desc);
        cache.erase(node->key);
        delete node;
        --size;
    }

    void addToTop(CacheNode *node) {
        node->next = head->next;
        node->next->prev = node;
        node->prev = head;
        head->next = node;
        cache[node->key] = node;
        if (++size > capacity) {
            removeNode(tail->prev);
        }
    }

    void moveToTop(CacheNode *node) {
        node->prev->next = node->next;
        node->next->prev = node->prev;
        node->next = head->next;
        node->next->prev = node;
        node->prev = head;
        head->next = node;
    }

public:
    LRUDescriptorCache(size_t c, OperatorType t)
        : capacity(c), size(0), opType(t) {
        head = new CacheNode();
        tail = new CacheNode();
        head->next = tail;
        tail->prev = head;
    }

    ~LRUDescriptorCache() {
        while (head->next != tail) {
            removeNode(head->next);
        }
        delete head;
        delete tail;
    }

    bool get(size_t key, DescriptorType &out_desc) {
        auto it = cache.find(key);
        if (it == cache.end()) {
            return false;
        }
        CacheNode *node = it->second;
        moveToTop(node);
        out_desc = node->desc;
        return true;
    }

    void put(size_t key, const DescriptorType &descriptor) {
        auto it = cache.find(key);
        if (it != cache.end()) {
            // Key already exists, update the descriptor
            CacheNode *node = it->second;
            destroyDescriptor(node->desc);
            node->desc = descriptor;
            moveToTop(node);
            return;
        }
        // Check if we need to evict
        if (size >= capacity) {
            removeNode(tail->prev);
        }
        // Create new node and add to top
        CacheNode *node = new CacheNode(key, descriptor);
        addToTop(node);
    }

    LRUDescriptorCache(const LRUDescriptorCache &) = delete;
    LRUDescriptorCache &operator=(const LRUDescriptorCache &) = delete;
};

class CacheManager {
private:
    const size_t DEFAULT_CACHE_CAPACITY = 100;

    LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
    LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
    LRUDescriptorCache<infiniopRoPEDescriptor_t> rope_cache;
    LRUDescriptorCache<infiniopRearrangeDescriptor_t> rearrange_cache;
    LRUDescriptorCache<infiniopCausalSoftmaxDescriptor_t> causal_softmax_cache;
    LRUDescriptorCache<infiniopSwiGLUDescriptor_t> swiglu_cache;
    LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;

public:
    CacheManager(size_t capacity = 100)
        : rms_norm_cache(capacity, OperatorType::RMS_NORM),
          gemm_cache(capacity, OperatorType::GEMM),
          rope_cache(capacity, OperatorType::ROPE),
          rearrange_cache(capacity, OperatorType::REARRANGE),
          causal_softmax_cache(capacity, OperatorType::CAUSAL_SOFTMAX),
          swiglu_cache(capacity, OperatorType::SWIGLU),
          random_sample_cache(capacity, OperatorType::RANDOM_SAMPLE) {}

    // RMSNorm operations
    bool getRMSNormDescriptor(size_t key, infiniopRMSNormDescriptor_t &desc) {
        return rms_norm_cache.get(key, desc);
    }
    void putRMSNormDescriptor(size_t key, const infiniopRMSNormDescriptor_t &desc) {
        rms_norm_cache.put(key, desc);
    }

    // GEMM operations
    bool getGemmDescriptor(size_t key, infiniopGemmDescriptor_t &desc) {
        return gemm_cache.get(key, desc);
    }
    void putGemmDescriptor(size_t key, const infiniopGemmDescriptor_t &desc) {
        gemm_cache.put(key, desc);
    }

    // RoPE operations
    bool getRoPEDescriptor(size_t key, infiniopRoPEDescriptor_t &desc) {
        return rope_cache.get(key, desc);
    }
    void putRoPEDescriptor(size_t key, const infiniopRoPEDescriptor_t &desc) {
        rope_cache.put(key, desc);
    }

    // Rearrange operations
    bool getRearrangeDescriptor(size_t key, infiniopRearrangeDescriptor_t &desc) {
        return rearrange_cache.get(key, desc);
    }
    void putRearrangeDescriptor(size_t key, const infiniopRearrangeDescriptor_t &desc) {
        rearrange_cache.put(key, desc);
    }

    // Softmax operations
    bool getCausalSoftmaxDescriptor(size_t key, infiniopCausalSoftmaxDescriptor_t &desc) {
        return causal_softmax_cache.get(key, desc);
    }
    void putCausalSoftmaxDescriptor(size_t key, const infiniopCausalSoftmaxDescriptor_t &desc) {
        causal_softmax_cache.put(key, desc);
    }

    // SwiGLU operations
    bool getSwiGLUDescriptor(size_t key, infiniopSwiGLUDescriptor_t &desc) {
        return swiglu_cache.get(key, desc);
    }
    void putSwiGLUDescriptor(size_t key, const infiniopSwiGLUDescriptor_t &desc) {
        swiglu_cache.put(key, desc);
    }

    // Random Sample operations
    bool getRandomSampleDescriptor(size_t key, infiniopRandomSampleDescriptor_t &desc) {
        return random_sample_cache.get(key, desc);
    }
    void putRandomSampleDescriptor(size_t key, const infiniopRandomSampleDescriptor_t &desc) {
        random_sample_cache.put(key, desc);
    }

    static size_t createDescriptorKey(std::shared_ptr<TensorDesc> desc0,
                                      std::shared_ptr<TensorDesc> desc1,
                                      std::shared_ptr<TensorDesc> desc2,
                                      std::shared_ptr<TensorDesc> desc3,
                                      std::shared_ptr<TensorDesc> desc4) {
        size_t seed = 0;
        if (desc0) {
            hash_combine(seed, computeTensorDescHash(desc0));
        }
        if (desc1) {
            hash_combine(seed, computeTensorDescHash(desc1));
        }
        if (desc2) {
            hash_combine(seed, computeTensorDescHash(desc2));
        }
        if (desc3) {
            hash_combine(seed, computeTensorDescHash(desc3));
        }
        if (desc4) {
            hash_combine(seed, computeTensorDescHash(desc4));
        }
        return seed;
    }
};

#endif // CACHE_MANAGER_HPP
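The cache above is a fixed-capacity LRU keyed by descriptor hashes: get() moves a hit to the head of an intrusive doubly linked list, put() inserts at the head and evicts tail->prev, and eviction destroys the descriptor through the opType switch. The standalone toy below mirrors that get/put/evict flow with std::list in place of the hand-rolled nodes; FakeDescriptor and expensiveCreate are hypothetical stand-ins for the infiniop descriptor types, not part of the commit.

// Minimal sketch of the LRU pattern used by LRUDescriptorCache.
#include <cstdio>
#include <list>
#include <unordered_map>
#include <utility>

struct FakeDescriptor { int id; };                        // hypothetical stand-in
FakeDescriptor *expensiveCreate(int id) { return new FakeDescriptor{id}; }

class TinyLRU {
    size_t capacity_;
    std::list<std::pair<size_t, FakeDescriptor *>> order_; // front = most recent
    std::unordered_map<size_t, decltype(order_)::iterator> index_;

public:
    explicit TinyLRU(size_t capacity) : capacity_(capacity) {}
    ~TinyLRU() {
        for (auto &kv : order_) delete kv.second;         // destroyDescriptor analogue
    }
    bool get(size_t key, FakeDescriptor *&out) {
        auto it = index_.find(key);
        if (it == index_.end()) return false;
        order_.splice(order_.begin(), order_, it->second); // moveToTop analogue
        out = order_.front().second;
        return true;
    }
    void put(size_t key, FakeDescriptor *desc) {          // assumes key is absent
        if (index_.size() >= capacity_) {                 // evict least recently used
            delete order_.back().second;
            index_.erase(order_.back().first);
            order_.pop_back();
        }
        order_.emplace_front(key, desc);
        index_[key] = order_.begin();
    }
};

int main() {
    TinyLRU cache(2);
    FakeDescriptor *d = nullptr;
    if (!cache.get(42, d)) {                              // miss: create once, reuse after
        d = expensiveCreate(42);
        cache.put(42, d);
    }
    std::printf("hit after put: %d\n", int(cache.get(42, d))); // prints 1
}

Unlike the header's version, this sketch evicts before inserting rather than after; the observable capacity behavior is the same.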
src/models/inference_context.cpp — new file (mode 100644), +188
#include "inference_context.hpp"
#include "../tensor.hpp"
#include "../utils.hpp"
InferenceContext
::
InferenceContext
(
DeviceResource
*
rsrc
,
CacheManager
*
cache_manager
,
infinirtStream_t
stream
)
:
rsrc
(
rsrc
),
cache_manager
(
cache_manager
),
stream
(
stream
)
{}
void
InferenceContext
::
ensure_workspace
(
size_t
required_size
)
{
if
(
required_size
>
current_workspace_size
)
{
workspace_storage
=
Storage
::
createFromPool
(
required_size
,
rsrc
->
memory_pool
);
current_workspace_size
=
required_size
;
}
}
void
InferenceContext
::
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
w
,
float
epsilon
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
y
->
tdesc
(),
x
->
tdesc
(),
w
->
tdesc
(),
nullptr
,
nullptr
);
infiniopRMSNormDescriptor_t
desc
;
if
(
!
cache_manager
->
getRMSNormDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateRMSNormDescriptor
(
rsrc
->
handle
,
&
desc
,
y
->
desc
(),
x
->
desc
(),
w
->
desc
(),
epsilon
));
cache_manager
->
putRMSNormDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetRMSNormWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopRMSNorm
(
desc
,
workspace
,
workspace_size
,
y
->
data
(),
x
->
data
(),
w
->
data
(),
stream
));
}
void
InferenceContext
::
gemm
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
TensorDesc
>
c_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
TensorDesc
>
a_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
b
,
std
::
shared_ptr
<
TensorDesc
>
b_desc_overwrite
,
float
alpha
,
float
beta
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
c_desc_overwrite
?
c_desc_overwrite
:
c
->
tdesc
(),
a_desc_overwrite
?
a_desc_overwrite
:
a
->
tdesc
(),
b_desc_overwrite
?
b_desc_overwrite
:
b
->
tdesc
(),
nullptr
,
nullptr
);
infiniopGemmDescriptor_t
desc
;
if
(
!
cache_manager
->
getGemmDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateGemmDescriptor
(
rsrc
->
handle
,
&
desc
,
c_desc_overwrite
?
c_desc_overwrite
->
desc
()
:
c
->
desc
(),
a_desc_overwrite
?
a_desc_overwrite
->
desc
()
:
a
->
desc
(),
b_desc_overwrite
?
b_desc_overwrite
->
desc
()
:
b
->
desc
()));
cache_manager
->
putGemmDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetGemmWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopGemm
(
desc
,
workspace
,
workspace_size
,
c
->
data
(),
a
->
data
(),
b
->
data
(),
alpha
,
beta
,
stream
));
}
void
InferenceContext
::
rearrange
(
std
::
shared_ptr
<
Tensor
>
dst
,
std
::
shared_ptr
<
TensorDesc
>
dst_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
src
,
std
::
shared_ptr
<
TensorDesc
>
src_desc_overwrite
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
dst_desc_overwrite
?
dst_desc_overwrite
:
dst
->
tdesc
(),
src_desc_overwrite
?
src_desc_overwrite
:
src
->
tdesc
(),
nullptr
,
nullptr
,
nullptr
);
infiniopRearrangeDescriptor_t
desc
;
if
(
!
cache_manager
->
getRearrangeDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateRearrangeDescriptor
(
rsrc
->
handle
,
&
desc
,
dst_desc_overwrite
?
dst_desc_overwrite
->
desc
()
:
dst
->
desc
(),
src_desc_overwrite
?
src_desc_overwrite
->
desc
()
:
src
->
desc
()));
cache_manager
->
putRearrangeDescriptor
(
key
,
desc
);
}
RUN_INFINI
(
infiniopRearrange
(
desc
,
dst
->
data
(),
src
->
data
(),
stream
));
}
void
InferenceContext
::
rope
(
std
::
shared_ptr
<
Tensor
>
q
,
std
::
shared_ptr
<
Tensor
>
k
,
std
::
shared_ptr
<
Tensor
>
pos
,
std
::
shared_ptr
<
Tensor
>
sin
,
std
::
shared_ptr
<
Tensor
>
cos
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
q
->
tdesc
(),
k
->
tdesc
(),
pos
->
tdesc
(),
sin
->
tdesc
(),
cos
->
tdesc
());
infiniopRoPEDescriptor_t
desc
;
if
(
!
cache_manager
->
getRoPEDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateRoPEDescriptor
(
rsrc
->
handle
,
&
desc
,
q
->
desc
(),
k
->
desc
(),
pos
->
desc
(),
sin
->
desc
(),
cos
->
desc
()));
cache_manager
->
putRoPEDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetRoPEWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopRoPE
(
desc
,
workspace
,
workspace_size
,
q
->
data
(),
k
->
data
(),
pos
->
data
(),
sin
->
data
(),
cos
->
data
(),
stream
));
}
void
InferenceContext
::
causalSoftmax
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
TensorDesc
>
y_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
TensorDesc
>
x_desc_overwrite
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
y_desc_overwrite
?
y_desc_overwrite
:
y
->
tdesc
(),
x_desc_overwrite
?
x_desc_overwrite
:
x
->
tdesc
(),
nullptr
,
nullptr
,
nullptr
);
infiniopCausalSoftmaxDescriptor_t
desc
;
if
(
!
cache_manager
->
getCausalSoftmaxDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateCausalSoftmaxDescriptor
(
rsrc
->
handle
,
&
desc
,
y_desc_overwrite
?
y_desc_overwrite
->
desc
()
:
y
->
desc
(),
x_desc_overwrite
?
x_desc_overwrite
->
desc
()
:
x
->
desc
()));
cache_manager
->
putCausalSoftmaxDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetCausalSoftmaxWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopCausalSoftmax
(
desc
,
workspace
,
workspace_size
,
y
->
data
(),
x
->
data
(),
stream
));
}
void
InferenceContext
::
swiglu
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
Tensor
>
up
,
std
::
shared_ptr
<
Tensor
>
gate
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
out
->
tdesc
(),
up
->
tdesc
(),
gate
->
tdesc
(),
nullptr
,
nullptr
);
infiniopSwiGLUDescriptor_t
desc
;
if
(
!
cache_manager
->
getSwiGLUDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateSwiGLUDescriptor
(
rsrc
->
handle
,
&
desc
,
out
->
desc
(),
up
->
desc
(),
gate
->
desc
()));
cache_manager
->
putSwiGLUDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetSwiGLUWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopSwiGLU
(
desc
,
workspace
,
workspace_size
,
out
->
data
(),
up
->
data
(),
gate
->
data
(),
stream
));
}
void
InferenceContext
::
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
TensorDesc
>
out_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
prob
,
std
::
shared_ptr
<
TensorDesc
>
prob_desc_overwrite
,
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
out_desc_overwrite
?
out_desc_overwrite
:
out
->
tdesc
(),
prob_desc_overwrite
?
prob_desc_overwrite
:
prob
->
tdesc
(),
nullptr
,
nullptr
,
nullptr
);
infiniopRandomSampleDescriptor_t
desc
;
if
(
!
cache_manager
->
getRandomSampleDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateRandomSampleDescriptor
(
rsrc
->
handle
,
&
desc
,
out_desc_overwrite
?
out_desc_overwrite
->
desc
()
:
out
->
desc
(),
prob_desc_overwrite
?
prob_desc_overwrite
->
desc
()
:
prob
->
desc
()));
cache_manager
->
putRandomSampleDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetRandomSampleWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopRandomSample
(
desc
,
workspace
,
workspace_size
,
out
->
data
(),
prob
->
data
(),
random_val
,
top_p
,
top_k
,
temperature
,
stream
));
}
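Every operator wrapper above follows one pattern: hash the descriptors into a key, hit or fill the LRU, query the workspace size, grow the shared scratch buffer if needed, then launch on the stream. ensure_workspace() is deliberately grow-only, so once the workspace reaches the high-water mark for a batch shape, later calls allocate nothing. A minimal sketch of that policy, with a plain heap buffer standing in for Storage and the memory pool:

// Grow-only scratch buffer, mirroring InferenceContext::ensure_workspace().
#include <cstdio>
#include <memory>

class Workspace {
    std::unique_ptr<char[]> buf_;
    size_t size_ = 0;

public:
    void ensure(size_t required) {
        if (required > size_) { // reallocate only when the requirement grows
            buf_ = std::make_unique<char[]>(required);
            size_ = required;
        }
    }
    void *memory() { return buf_.get(); }
    size_t size() const { return size_; }
};

int main() {
    Workspace ws;
    const size_t needs[] = {1024, 512, 4096, 4096}; // per-op workspace requirements
    for (size_t n : needs) {
        ws.ensure(n);
        std::printf("need %zu -> hold %zu\n", n, ws.size());
    }
    // Only the 1024 and 4096 requests allocate; the other two reuse the buffer.
}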
src/models/inference_context.hpp — new file (mode 100644), +39
// inference_context.hpp
#pragma once

#include "cache_manager.hpp"
#include "jiuge/jiuge_impl.hpp"
#include "jiuge/jiuge_weight.hpp"

struct InferenceContext {
    DeviceResource *rsrc;
    CacheManager *cache_manager;
    infinirtStream_t stream;
    std::shared_ptr<Storage> workspace_storage;
    size_t current_workspace_size = 0;

    InferenceContext(DeviceResource *rsrc, CacheManager *cache_manager,
                     infinirtStream_t stream);

    void ensure_workspace(size_t required_size);
    void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w, float epsilon);
    void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDesc> c_desc_overwrite,
              std::shared_ptr<Tensor> a, std::shared_ptr<TensorDesc> a_desc_overwrite,
              std::shared_ptr<Tensor> b, std::shared_ptr<TensorDesc> b_desc_overwrite,
              float alpha, float beta);
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<TensorDesc> dst_desc_overwrite,
                   std::shared_ptr<Tensor> src,
                   std::shared_ptr<TensorDesc> src_desc_overwrite);
    void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
              std::shared_ptr<Tensor> cos);
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<TensorDesc> y_desc_overwrite,
                       std::shared_ptr<Tensor> x,
                       std::shared_ptr<TensorDesc> x_desc_overwrite);
    void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<TensorDesc> out_desc_overwrite,
                      std::shared_ptr<Tensor> prob,
                      std::shared_ptr<TensorDesc> prob_desc_overwrite,
                      float random_val, float top_p, uint32_t top_k, float temperature);
};
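With this header in place, a caller holds one CacheManager and one InferenceContext per device and issues operators without ever naming an infiniop descriptor. A minimal sketch of that wiring, modeled on the launchDevice() change in jiuge.cpp below; it assumes the repo's headers and an already-initialized DeviceResource, so it is illustrative rather than a standalone program:

#include "inference_context.hpp"

// Sketch only: `rsrc` and the tensors are assumed to be set up elsewhere.
void rmsnorm_twice(DeviceResource *rsrc, std::shared_ptr<Tensor> y,
                   std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w, float epsilon) {
    CacheManager cache_manager(100); // descriptor LRU, capacity 100 per operator type
    InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
    ctx.rmsnorm(y, x, w, epsilon); // first call: creates and caches the descriptor
    ctx.rmsnorm(y, x, w, epsilon); // same shapes: cache hit, nothing rebuilt
}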
src/models/jiuge/jiuge.cpp — +73, −244
@@ -3,6 +3,7 @@
 #include "../../tensor.hpp"
 #include "../../utils.hpp"
+#include "../inference_context.hpp"
 #include "infinicore_infer.h"
 #include <random>

@@ -116,7 +117,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
                       const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                       struct KVCache **kv_caches,
                       const float *temperature, const uint32_t *topk, const float *topp,
-                      uint32_t *output) {
+                      uint32_t *output, InferenceContext &ctx) {
     auto nlayer = meta.nlayer;
     auto nkvh = meta.nkvh / ndev;
     auto nh = meta.nh / ndev;

@@ -164,239 +165,97 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
                                dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
     }
-    // Prepare operators and workspace
-    size_t workspace_size = 0, temp_size = 0;
-    // attn & mlp rmsnorm
-    infiniopRMSNormDescriptor_t desc_norm;
-    RUN_INFINI(infiniopCreateRMSNormDescriptor(
-        rsrc.handle, &desc_norm, logits_in->desc(), logits_out->desc(),
-        rsrc.w_attn_norm[0]->desc(), meta.epsilon));
-    RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm, &workspace_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    // Attention
-    infiniopGemmDescriptor_t desc_attn_qkv, desc_attn_o;
-    infiniopRearrangeDescriptor_t desc_qkv_bias;
-    if (has_qkv_bias) {
-        RUN_INFINI(infiniopCreateRearrangeDescriptor(
-            rsrc.handle, &desc_qkv_bias, qkv_buf->desc(),
-            TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->desc()));
-    }
-    RUN_INFINI(infiniopCreateGemmDescriptor(rsrc.handle, &desc_attn_qkv, qkv_buf->desc(),
-                                            logits_in->desc(), rsrc.w_attn_qkv[0]->desc()));
-    RUN_INFINI(infiniopCreateGemmDescriptor(rsrc.handle, &desc_attn_o, logits_in->desc(),
-                                            o_buf->desc(), rsrc.w_attn_out[0]->desc()));
-    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_qkv, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_o, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    infiniopRoPEDescriptor_t desc_rope_q, desc_rope_k;
+    auto qkv_desc = TensorDesc::create(dt_logits, qkv_buf->shape(), qkv_buf->strides());
+    auto b_attn_qkv_desc = TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh},
+                                              {0, 1});
+    auto o_desc = TensorDesc::create(dt_logits, o_buf->shape(), o_buf->strides());
     qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
-    auto qkv_buf_q = qkv_buf->slice(1, 0, nh);
-    auto qkv_buf_k = qkv_buf->slice(1, nh, nkvh);
-    RUN_INFINI(infiniopCreateRoPEDescriptor(
-        rsrc.handle, &desc_rope_q, qkv_buf_q->desc(), qkv_buf_q->desc(),
-        pos_ids_buf->desc(), rsrc.sin_table->desc(), rsrc.cos_table->desc()));
-    RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_q, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    RUN_INFINI(infiniopCreateRoPEDescriptor(
-        rsrc.handle, &desc_rope_k, qkv_buf_k->desc(), qkv_buf_k->desc(),
-        pos_ids_buf->desc(), rsrc.sin_table->desc(), rsrc.cos_table->desc()));
-    RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_k, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    // attention inner
-    auto desc_kv_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
-    auto desc_q_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
-    auto desc_qk_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
-    auto desc_qk_softmaxs = std::vector<infiniopCausalSoftmaxDescriptor_t>(nreq);
-    auto desc_attn_v_gemms = std::vector<infiniopGemmDescriptor_t>(nreq);
-    auto desc_attn_v_rearranges = std::vector<infiniopRearrangeDescriptor_t>(nreq);
     size_t token_offset = 0;
     size_t max_qk_size = 0;
     size_t max_seq_len = 0;
     o_buf->dimSplit(1, {nh, dh});
     for (uint32_t req = 0; req < nreq; req++) {
         auto past_len = req_pos[req];
         auto seq_len = req_lens[req];
         auto total_len = past_len + seq_len;
-        auto o = o_buf->slice({{0, token_offset, seq_len}});
-        auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
-        auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
-        // auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
-        // kv cache tensors can share the same descriptor
-        // [nkvh, dh, total_len]
-        auto full_kv = kv_caches[req]->k[idev][0]->slice(0, 0, total_len)->permute({1, 2, 0});
-        auto cache_kv = kv_caches[req]->k[idev][0]->slice(0, past_len, seq_len);
-        RUN_INFINI(infiniopCreateRearrangeDescriptor(
-            rsrc.handle, &desc_kv_rearranges[req], cache_kv->desc(), k->desc()));
-        // [nkvh, ngroup, seq_len, dh]
-        q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
-        auto q_t = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
-        // [seq_len, nkvh, ngroup, dh] -> [nkvh, ngroup, seq_len, dh]
-        RUN_INFINI(infiniopCreateRearrangeDescriptor(
-            rsrc.handle, &desc_q_rearranges[req], q_t->desc(), q->desc()));
-        // [nkvh, ngroup, seq_len, dh] -> [seq_len, nkvh, ngroup, dh]
-        auto attn_v_t = q_t;
-        auto attn_v = TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh},
-                                                  {1, 2, 0, 3});
-        RUN_INFINI(infiniopCreateRearrangeDescriptor(
-            rsrc.handle, &desc_attn_v_rearranges[req], attn_v->desc(), attn_v_t->desc()));
-        q_t = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
-        auto qk = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
         max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
         max_seq_len = std::max(max_seq_len, size_t(seq_len));
-        RUN_INFINI(infiniopCreateGemmDescriptor(rsrc.handle, &desc_qk_gemms[req],
-                                                qk->desc(), q_t->desc(), full_kv->desc()));
-        RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_qk_gemms[req], &temp_size));
-        workspace_size = std::max(workspace_size, temp_size);
-        // [nkvh, total_len, dh]
-        auto full_v = kv_caches[req]->v[idev][0]->slice(0, 0, total_len)->permute({1, 0, 2});
-        RUN_INFINI(infiniopCreateGemmDescriptor(rsrc.handle, &desc_attn_v_gemms[req],
-                                                q_t->desc(), qk->desc(), full_v->desc()));
-        RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_v_gemms[req], &temp_size));
-        workspace_size = std::max(workspace_size, temp_size);
-        qk = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len});
-        RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
-            rsrc.handle, &desc_qk_softmaxs[req], qk->desc(), qk->desc()));
-        RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc_qk_softmaxs[req], &temp_size));
-        workspace_size = std::max(workspace_size, temp_size);
         token_offset += seq_len;
     }
     auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
     auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh},
                                           rsrc.memory_pool);
     auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, dh}, rsrc.memory_pool);
-    // MLP descriptors
-    infiniopGemmDescriptor_t desc_ffn_gate_up, desc_ffn_down;
-    infiniopSwiGLUDescriptor_t desc_swiglu;
-    RUN_INFINI(infiniopCreateGemmDescriptor(rsrc.handle, &desc_ffn_gate_up,
-                                            gate_up_buf->desc(), logits_out->desc(),
-                                            rsrc.w_ffn_gate_up[0]->desc()));
-    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_gate_up, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
     // MLP buffers
     auto gate_buf = gate_up_buf->slice(1, 0, di);
     auto up_buf = gate_up_buf->slice(1, di, di);
-    RUN_INFINI(infiniopCreateSwiGLUDescriptor(rsrc.handle, &desc_swiglu, gate_buf->desc(),
-                                              up_buf->desc(), gate_buf->desc()));
-    RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc_swiglu, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    RUN_INFINI(infiniopCreateGemmDescriptor(rsrc.handle, &desc_ffn_down, logits_in->desc(),
-                                            gate_buf->desc(), rsrc.w_ffn_down[0]->desc()));
-    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_down, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    // Output and sample
-    infiniopRMSNormDescriptor_t desc_norm_out;
-    RUN_INFINI(infiniopCreateRMSNormDescriptor(
-        rsrc.handle, &desc_norm_out, logits_out->slice(0, 0, 1)->desc(),
-        logits_out->slice(0, 0, 1)->desc(), rsrc.w_out_norm->desc(), meta.epsilon));
-    RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm_out, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    infiniopGemmDescriptor_t desc_out_embd;
-    RUN_INFINI(infiniopCreateGemmDescriptor(
-        rsrc.handle, &desc_out_embd, prob_buf->desc(),
-        logits_out->slice(0, 0, nreq)->desc(), rsrc.w_out_embd->desc()));
-    RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_out_embd, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    infiniopRandomSampleDescriptor_t desc_sample;
-    RUN_INFINI(infiniopCreateRandomSampleDescriptor(
-        rsrc.handle, &desc_sample, TensorDesc::create(INFINI_DTYPE_I64, {}, {})->desc(),
-        TensorDesc::create(dt_logits, {dvoc}, {1})->desc()));
-    RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
-    workspace_size = std::max(workspace_size, temp_size);
-    // Allocate workspace
-    std::shared_ptr<Storage> workspace_storage = Storage::createFromPool(workspace_size,
-                                                                         rsrc.memory_pool);
-    void *workspace = workspace_storage->memory();
+    auto result_desc = TensorDesc::create(INFINI_DTYPE_I64, {}, {});
+    auto prob_desc = TensorDesc::create(dt_logits, {dvoc}, {1});
     // Compute
     for (uint32_t layer = 0; layer < nlayer; layer++) {
         // 1. Attention
         // rms norm
-        RUN_INFINI(infiniopRMSNorm(desc_norm, workspace, workspace_size,
-                                   logits_out->data(), logits_in->data(),
-                                   rsrc.w_attn_norm[layer]->data(), stream));
+        ctx.rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
         // qkv_proj
         if (has_qkv_bias) {
-            RUN_INFINI(infiniopRearrange(desc_qkv_bias, qkv_buf->data(),
-                                         rsrc.b_attn_qkv[layer]->data(), stream));
-        }
-        RUN_INFINI(infiniopGemm(desc_attn_qkv, workspace, workspace_size, qkv_buf->data(),
-                                logits_out->data(), rsrc.w_attn_qkv[layer]->data(), 1.0,
-                                has_qkv_bias ? 1.0 : 0.0, stream));
+            ctx.rearrange(qkv_buf, qkv_desc, rsrc.b_attn_qkv[layer], b_attn_qkv_desc);
+        }
+        ctx.gemm(qkv_buf, qkv_desc, logits_out, nullptr, rsrc.w_attn_qkv[layer], nullptr,
+                 1.0, has_qkv_bias ? 1.0 : 0.0);
         // rope
-        RUN_INFINI(infiniopRoPE(desc_rope_q, workspace, workspace_size, qkv_buf->data(),
-                                qkv_buf->data(), pos_ids_buf->data(),
-                                rsrc.sin_table->data(), rsrc.cos_table->data(), stream));
-        RUN_INFINI(infiniopRoPE(desc_rope_k, workspace, workspace_size,
-                                qkv_buf->data(nh * dh), qkv_buf->data(nh * dh),
-                                pos_ids_buf->data(), rsrc.sin_table->data(),
-                                rsrc.cos_table->data(), stream));
+        ctx.rope(qkv_buf->slice(1, 0, nh), qkv_buf->slice(1, 0, nh), pos_ids_buf,
+                 rsrc.sin_table, rsrc.cos_table);
+        ctx.rope(qkv_buf->slice(1, nh, nkvh), qkv_buf->slice(1, nh, nkvh), pos_ids_buf,
+                 rsrc.sin_table, rsrc.cos_table);
         size_t token_offset = 0;
         for (uint32_t req = 0; req < nreq; req++) {
             auto past_len = req_pos[req];
            auto seq_len = req_lens[req];
            auto total_len = past_len + seq_len;
            auto o = o_buf->slice({{0, token_offset, seq_len}});
            auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
            auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
            auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
+            auto qt_rearrange_desc = TensorDesc::create(dt_logits,
+                                                        {nkvh, ngroup, seq_len, dh});
+            auto qt_gemm_desc = TensorDesc::create(dt_logits,
+                                                   {nkvh, ngroup * seq_len, dh});
+            auto qk_gemm_desc = TensorDesc::create(dt_logits,
+                                                   {nkvh, ngroup * seq_len, total_len});
             // self attention
             // concat
-            RUN_INFINI(infiniopRearrange(
-                desc_kv_rearranges[req],
-                kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
-                k->data(), stream));
-            RUN_INFINI(infiniopRearrange(
-                desc_kv_rearranges[req],
-                kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
-                v->data(), stream));
+            ctx.rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len),
+                          nullptr, k, nullptr);
+            ctx.rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len),
+                          nullptr, v, nullptr);
             // qk
-            RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(),
-                                         q->data(), stream));
-            RUN_INFINI(infiniopGemm(desc_qk_gemms[req], workspace, workspace_size,
-                                    qk_buf->data(), rearrange_q_buf->data(),
-                                    kv_caches[req]->k[idev][layer]->data(),
-                                    1. / sqrt(dh), 0.0, stream));
+            ctx.rearrange(rearrange_q_buf, qt_rearrange_desc,
+                          q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3}), nullptr);
+            ctx.gemm(qk_buf, qk_gemm_desc, rearrange_q_buf, qt_gemm_desc,
+                     kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)
+                         ->permute({1, 2, 0}),
+                     nullptr, 1. / sqrt(dh), 0.0);
             // softmax
-            RUN_INFINI(infiniopCausalSoftmax(desc_qk_softmaxs[req], workspace,
-                                             workspace_size, qk_buf->data(),
-                                             qk_buf->data(), stream));
-            // attn val
-            RUN_INFINI(infiniopGemm(desc_attn_v_gemms[req], workspace, workspace_size,
-                                    attn_val_buf->data(), qk_buf->data(),
-                                    kv_caches[req]->v[idev][layer]->data(), 1.0, 0.0,
-                                    stream));
+            auto qk_desc = TensorDesc::create(dt_logits,
+                                              {nkvh * ngroup, seq_len, total_len});
+            ctx.causalSoftmax(qk_buf, qk_desc, qk_buf, qk_desc);
+            // attn val
+            ctx.gemm(attn_val_buf, qt_gemm_desc, qk_buf, qk_gemm_desc,
+                     kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)
+                         ->permute({1, 0, 2}),
+                     nullptr, 1.0, 0.0);
             // rearrange attn val
-            RUN_INFINI(infiniopRearrange(desc_attn_v_rearranges[req], o->data(),
-                                         attn_val_buf->data(), stream));
+            ctx.rearrange(o,
+                          TensorDesc::createWithOrder(dt_logits,
+                                                      {nkvh, ngroup, seq_len, dh},
+                                                      {1, 2, 0, 3}),
+                          attn_val_buf, qt_rearrange_desc);
             token_offset += seq_len;
         }
         // o_proj
-        RUN_INFINI(infiniopGemm(desc_attn_o, workspace, workspace_size, logits_in->data(),
-                                o_buf->data(), rsrc.w_attn_out[layer]->data(), 1.0,
-                                idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual
+        ctx.gemm(logits_in, nullptr, o_buf, o_desc, rsrc.w_attn_out[layer], nullptr, 1.0,
+                 idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
         // All_reduce if distributed
         if (rsrc.comm != nullptr) {

@@ -407,21 +266,16 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         }
         // 2. FFN
         // rms_norm
-        RUN_INFINI(infiniopRMSNorm(desc_norm, workspace, workspace_size,
-                                   logits_out->data(), logits_in->data(),
-                                   rsrc.w_ffn_norm[layer]->data(), stream));
-        RUN_INFINI(infiniopGemm(desc_ffn_gate_up, workspace, workspace_size,
-                                gate_up_buf->data(), logits_out->data(),
-                                rsrc.w_ffn_gate_up[layer]->data(), 1.0, 0.0, stream));
-        RUN_INFINI(infiniopSwiGLU(desc_swiglu, workspace, workspace_size,
-                                  gate_buf->data(), up_buf->data(), gate_buf->data(),
-                                  stream));
-        RUN_INFINI(infiniopGemm(desc_ffn_down, workspace, workspace_size,
-                                logits_in->data(), gate_buf->data(),
-                                rsrc.w_ffn_down[layer]->data(), 1.0,
-                                idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual
+        ctx.rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon);
+        ctx.gemm(gate_up_buf, nullptr, logits_out, nullptr, rsrc.w_ffn_gate_up[layer],
+                 nullptr, 1.0, 0.0);
+        ctx.swiglu(gate_buf, up_buf, gate_buf);
+        ctx.gemm(logits_in, nullptr, gate_buf, nullptr, rsrc.w_ffn_down[layer], nullptr,
+                 1.0, idev == 0 ? 1.0 : 0.0); // only rank 0 adds residual
         // All_reduce if distributed
         if (rsrc.comm != nullptr) {

@@ -437,31 +291,24 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     for (uint32_t req = 0; req < nreq; req++) {
         auto seq_len = req_lens[req];
         token_offset += seq_len;
-        RUN_INFINI(infiniopRMSNorm(desc_norm_out, workspace, workspace_size,
-                                   logits_out->data(req * d),
-                                   logits_in->data((token_offset - 1) * d),
-                                   rsrc.w_out_norm->data(), stream));
-    }
-    RUN_INFINI(infiniopGemm(desc_out_embd, workspace, workspace_size, prob_buf->data(),
-                            logits_out->data(), rsrc.w_out_embd->data(), 1.0, 0.0,
-                            stream));
+        ctx.rmsnorm(logits_out->slice(0, req, 1),
+                    logits_in->slice(0, token_offset - 1, 1), rsrc.w_out_norm,
+                    meta.epsilon);
+    }
+    ctx.gemm(prob_buf, nullptr, logits_out->slice(0, 0, nreq), nullptr, rsrc.w_out_embd,
+             nullptr, 1.0, 0.0);
     std::random_device _rd;
     std::mt19937 gen(_rd());
     token_offset = 0;
     for (uint32_t req = 0; req < nreq; req++) {
         auto seq_len = req_lens[req];
         float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
         // prob_buf->debug();
-        RUN_INFINI(infiniopRandomSample(desc_sample, workspace, workspace_size,
-                                        result_buf->data(req), prob_buf->data(req * dvoc),
-                                        random_val, topp[req], topk[req],
-                                        temperature[req], stream));
-        // result_buf->debug();
+        ctx.randomSample(result_buf->slice(0, req, 1), result_desc,
+                         prob_buf->slice(0, req, 1), prob_desc, random_val, topp[req],
+                         topk[req], temperature[req]);
         token_offset += seq_len;
     }
     RUN_INFINI(infinirtStreamSynchronize(stream));

@@ -471,30 +318,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
             output[req] = result_cpu[req];
         }
     }
-    // Clean up
-    infiniopDestroyRMSNormDescriptor(desc_norm);
-    if (has_qkv_bias) {
-        infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
-    }
-    infiniopDestroyGemmDescriptor(desc_attn_qkv);
-    infiniopDestroyGemmDescriptor(desc_attn_o);
-    infiniopDestroyRoPEDescriptor(desc_rope_q);
-    infiniopDestroyRoPEDescriptor(desc_rope_k);
-    for (uint32_t req = 0; req < nreq; req++) {
-        infiniopDestroyRearrangeDescriptor(desc_kv_rearranges[req]);
-        infiniopDestroyRearrangeDescriptor(desc_q_rearranges[req]);
-        infiniopDestroyGemmDescriptor(desc_qk_gemms[req]);
-        infiniopDestroyCausalSoftmaxDescriptor(desc_qk_softmaxs[req]);
-        infiniopDestroyGemmDescriptor(desc_attn_v_gemms[req]);
-        infiniopDestroyRearrangeDescriptor(desc_attn_v_rearranges[req]);
-    }
-    infiniopDestroyGemmDescriptor(desc_ffn_gate_up);
-    infiniopDestroySwiGLUDescriptor(desc_swiglu);
-    infiniopDestroyGemmDescriptor(desc_ffn_down);
-    infiniopDestroyRMSNormDescriptor(desc_norm_out);
-    infiniopDestroyGemmDescriptor(desc_out_embd);
-    infiniopDestroyRandomSampleDescriptor(desc_sample);
 }
 __C void

@@ -531,6 +354,9 @@ inferBatch(struct JiugeModel *model,
 void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights,
                   DeviceResource *rsrc, InferState &state, InferRequest &req,
                   infiniDevice_t device, int idev, int ndev, int dev_id,
                   infinicclComm_t comm) {
+    CacheManager cache_manager(100);
+    InferenceContext ctx(rsrc, &cache_manager, rsrc->stream);
     // Create Device Resource
     createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
     {

@@ -549,7 +375,10 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource
             break;
         }
-        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens,
-                         req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk,
-                         req.topp, req.output);
+        inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens,
+                         req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk,
+                         req.topp, req.output, ctx);
         state.proceed = false;
         lock.unlock();
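The payoff of the diff above is that per-batch descriptor setup and the whole // Clean up block disappear: descriptors now live in the LRU and are reused whenever dtype, shape, and strides repeat, which is exactly the steady state of token-by-token decoding. The standalone check below confirms that key stability using the same hash_combine scheme as cache_manager.hpp; ToyDesc is a hypothetical stand-in for the repo's TensorDesc:

// Descriptor keys are pure functions of {dtype, shape, strides}.
#include <cstddef>
#include <cstdio>
#include <vector>

inline void hash_combine(size_t &seed, size_t value) { // as in cache_manager.hpp
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct ToyDesc { // hypothetical stand-in for TensorDesc metadata
    int dtype;
    std::vector<size_t> shape;
    std::vector<ptrdiff_t> strides;
};

size_t descHash(const ToyDesc &d) { // mirrors computeTensorDescHash()
    size_t seed = 0;
    hash_combine(seed, static_cast<size_t>(d.dtype));
    for (auto dim : d.shape) hash_combine(seed, dim);
    for (auto s : d.strides) hash_combine(seed, static_cast<size_t>(s));
    return seed;
}

int main() {
    ToyDesc a{1, {16, 4096}, {4096, 1}};
    ToyDesc b{1, {16, 4096}, {4096, 1}}; // next batch, same shapes -> same key
    ToyDesc c{1, {17, 4096}, {4096, 1}}; // one more token -> different key
    std::printf("a==b: %d  a==c: %d\n", int(descHash(a) == descHash(b)),
                int(descHash(a) == descHash(c)));
    // Expected: a==b: 1  a==c: 0 (barring an unlikely hash collision).
}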
src/tensor.hpp — +1, −0

@@ -120,6 +120,7 @@ public:
     infiniDtype_t dtype() const;
     bool isContigous() const;
     infiniopTensorDescriptor_t desc() const;
+    std::shared_ptr<TensorDesc> tdesc() const;
     ptrdiff_t dataOffset() const;
     infiniDevice_t deviceType() const;
     int deviceId() const;
src/tensor/strorage.cpp — +1, −1 (one-line change, not expanded in this view)
src/tensor/tensor.cpp — +1, −0

@@ -108,6 +108,7 @@ ptrdiff_t Tensor::dataOffset() const {
 }
 infiniopTensorDescriptor_t Tensor::desc() const { return _desc->desc(); }
+std::shared_ptr<TensorDesc> Tensor::tdesc() const { return _desc; }
 std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
                                        const std::vector<size_t> &shape,
xmake.lua — +1, −0

@@ -12,6 +12,7 @@ target("infinicore_infer")
     set_languages("cxx17")
     set_warnings("all", "error")
+    add_files("src/models/*.cpp")
     add_files("src/models/*/*.cpp")
     add_files("src/tensor/*.cpp")
     add_files("src/allocator/*.cpp")