Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
366386d3
"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "88e97c822c988eaa9f8bcbaa1ea5d702ffd7d384"
Commit
366386d3
authored
Jul 22, 2025
by
wooway777
Browse files
issue/21 - removed descriptor overwrite
parent
726f444f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
92 additions
and
129 deletions
+92
-129
src/models/cache_manager.hpp
src/models/cache_manager.hpp
+10
-19
src/models/inference_context.cpp
src/models/inference_context.cpp
+24
-45
src/models/inference_context.hpp
src/models/inference_context.hpp
+13
-10
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+32
-53
src/tensor.hpp
src/tensor.hpp
+2
-1
src/tensor/tensor.cpp
src/tensor/tensor.cpp
+0
-1
src/tensor/transform.cpp
src/tensor/transform.cpp
+11
-0
No files found.
src/models/cache_manager.hpp
View file @
366386d3
...
@@ -21,23 +21,14 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
...
@@ -21,23 +21,14 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
hash_combine
(
seed
,
static_cast
<
size_t
>
(
value
));
hash_combine
(
seed
,
static_cast
<
size_t
>
(
value
));
}
}
// Specialization for float to handle potential precision issues
inline
void
hash_combine
(
size_t
&
seed
,
float
value
)
{
// Treat float bits as uint32_t for consistent hashing
uint32_t
int_value
;
static_assert
(
sizeof
(
value
)
==
sizeof
(
int_value
),
"Size mismatch"
);
std
::
memcpy
(
&
int_value
,
&
value
,
sizeof
(
value
));
hash_combine
(
seed
,
static_cast
<
size_t
>
(
int_value
));
}
// Helper function to compute hash for tensor descriptors
// Helper function to compute hash for tensor descriptors
inline
size_t
computeTensorDescHash
(
std
::
shared_ptr
<
Tensor
Desc
>
desc
)
{
inline
size_t
computeTensorDescHash
(
std
::
shared_ptr
<
Tensor
>
tensor
)
{
size_t
seed
=
0
;
size_t
seed
=
0
;
hash_combine
(
seed
,
desc
->
dtype
());
hash_combine
(
seed
,
tensor
->
dtype
());
for
(
auto
dim
:
desc
->
shape
())
{
for
(
auto
dim
:
tensor
->
shape
())
{
hash_combine
(
seed
,
dim
);
hash_combine
(
seed
,
dim
);
}
}
for
(
auto
stride
:
desc
->
strides
())
{
for
(
auto
stride
:
tensor
->
strides
())
{
hash_combine
(
seed
,
static_cast
<
size_t
>
(
stride
));
hash_combine
(
seed
,
static_cast
<
size_t
>
(
stride
));
}
}
return
seed
;
return
seed
;
...
@@ -185,7 +176,7 @@ public:
...
@@ -185,7 +176,7 @@ public:
class
CacheManager
{
class
CacheManager
{
private:
private:
const
size_t
DEFAULT_CACHE_CAPACITY
=
1
00
;
const
size_t
DEFAULT_CACHE_CAPACITY
=
1
28
;
LRUDescriptorCache
<
infiniopRMSNormDescriptor_t
>
rms_norm_cache
;
LRUDescriptorCache
<
infiniopRMSNormDescriptor_t
>
rms_norm_cache
;
LRUDescriptorCache
<
infiniopGemmDescriptor_t
>
gemm_cache
;
LRUDescriptorCache
<
infiniopGemmDescriptor_t
>
gemm_cache
;
...
@@ -267,11 +258,11 @@ public:
...
@@ -267,11 +258,11 @@ public:
random_sample_cache
.
put
(
key
,
desc
);
random_sample_cache
.
put
(
key
,
desc
);
}
}
static
size_t
createDescriptorKey
(
std
::
shared_ptr
<
Tensor
Desc
>
desc0
,
static
size_t
createDescriptorKey
(
std
::
shared_ptr
<
Tensor
>
desc0
,
std
::
shared_ptr
<
Tensor
Desc
>
desc1
,
std
::
shared_ptr
<
Tensor
>
desc1
,
std
::
shared_ptr
<
Tensor
Desc
>
desc2
,
std
::
shared_ptr
<
Tensor
>
desc2
,
std
::
shared_ptr
<
Tensor
Desc
>
desc3
,
std
::
shared_ptr
<
Tensor
>
desc3
,
std
::
shared_ptr
<
Tensor
Desc
>
desc4
)
{
std
::
shared_ptr
<
Tensor
>
desc4
)
{
size_t
seed
=
0
;
size_t
seed
=
0
;
if
(
desc0
)
{
if
(
desc0
)
{
hash_combine
(
seed
,
computeTensorDescHash
(
desc0
));
hash_combine
(
seed
,
computeTensorDescHash
(
desc0
));
...
...
src/models/inference_context.cpp
View file @
366386d3
...
@@ -16,7 +16,7 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
...
@@ -16,7 +16,7 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
w
,
std
::
shared_ptr
<
Tensor
>
w
,
float
epsilon
)
{
float
epsilon
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
y
->
tdesc
(),
x
->
tdesc
(),
w
->
tdesc
()
,
nullptr
,
nullptr
);
size_t
key
=
CacheManager
::
createDescriptorKey
(
y
,
x
,
w
,
nullptr
,
nullptr
);
infiniopRMSNormDescriptor_t
desc
;
infiniopRMSNormDescriptor_t
desc
;
if
(
!
cache_manager
->
getRMSNormDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getRMSNormDescriptor
(
key
,
desc
))
{
...
@@ -35,23 +35,16 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
...
@@ -35,23 +35,16 @@ void InferenceContext::rmsnorm(std::shared_ptr<Tensor> y,
y
->
data
(),
x
->
data
(),
w
->
data
(),
stream
));
y
->
data
(),
x
->
data
(),
w
->
data
(),
stream
));
}
}
void
InferenceContext
::
gemm
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
TensorDesc
>
c_desc_overwrite
,
void
InferenceContext
::
gemm
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
TensorDesc
>
a_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
,
std
::
shared_ptr
<
TensorDesc
>
b_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
)
{
float
alpha
,
float
beta
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
size_t
key
=
CacheManager
::
createDescriptorKey
(
c
,
a
,
b
,
c_desc_overwrite
?
c_desc_overwrite
:
c
->
tdesc
(),
nullptr
,
nullptr
);
a_desc_overwrite
?
a_desc_overwrite
:
a
->
tdesc
(),
b_desc_overwrite
?
b_desc_overwrite
:
b
->
tdesc
(),
nullptr
,
nullptr
);
infiniopGemmDescriptor_t
desc
;
infiniopGemmDescriptor_t
desc
;
if
(
!
cache_manager
->
getGemmDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getGemmDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateGemmDescriptor
(
RUN_INFINI
(
infiniopCreateGemmDescriptor
(
rsrc
->
handle
,
&
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
rsrc
->
handle
,
&
desc
,
c_desc_overwrite
?
c_desc_overwrite
->
desc
()
:
c
->
desc
(),
a_desc_overwrite
?
a_desc_overwrite
->
desc
()
:
a
->
desc
(),
b_desc_overwrite
?
b_desc_overwrite
->
desc
()
:
b
->
desc
()));
cache_manager
->
putGemmDescriptor
(
key
,
desc
);
cache_manager
->
putGemmDescriptor
(
key
,
desc
);
}
}
...
@@ -65,19 +58,13 @@ void InferenceContext::gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDes
...
@@ -65,19 +58,13 @@ void InferenceContext::gemm(std::shared_ptr<Tensor> c, std::shared_ptr<TensorDes
c
->
data
(),
a
->
data
(),
b
->
data
(),
alpha
,
beta
,
stream
));
c
->
data
(),
a
->
data
(),
b
->
data
(),
alpha
,
beta
,
stream
));
}
}
void
InferenceContext
::
rearrange
(
std
::
shared_ptr
<
Tensor
>
dst
,
std
::
shared_ptr
<
TensorDesc
>
dst_desc_overwrite
,
void
InferenceContext
::
rearrange
(
std
::
shared_ptr
<
Tensor
>
dst
,
std
::
shared_ptr
<
Tensor
>
src
,
std
::
shared_ptr
<
TensorDesc
>
src_desc_overwrite
)
{
std
::
shared_ptr
<
Tensor
>
src
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
size_t
key
=
CacheManager
::
createDescriptorKey
(
dst
,
src
,
nullptr
,
nullptr
,
nullptr
);
dst_desc_overwrite
?
dst_desc_overwrite
:
dst
->
tdesc
(),
src_desc_overwrite
?
src_desc_overwrite
:
src
->
tdesc
(),
nullptr
,
nullptr
,
nullptr
);
infiniopRearrangeDescriptor_t
desc
;
infiniopRearrangeDescriptor_t
desc
;
if
(
!
cache_manager
->
getRearrangeDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getRearrangeDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateRearrangeDescriptor
(
RUN_INFINI
(
infiniopCreateRearrangeDescriptor
(
rsrc
->
handle
,
&
desc
,
dst
->
desc
(),
src
->
desc
()));
rsrc
->
handle
,
&
desc
,
dst_desc_overwrite
?
dst_desc_overwrite
->
desc
()
:
dst
->
desc
(),
src_desc_overwrite
?
src_desc_overwrite
->
desc
()
:
src
->
desc
()));
cache_manager
->
putRearrangeDescriptor
(
key
,
desc
);
cache_manager
->
putRearrangeDescriptor
(
key
,
desc
);
}
}
...
@@ -93,7 +80,7 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
...
@@ -93,7 +80,7 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
std
::
shared_ptr
<
Tensor
>
pos
,
std
::
shared_ptr
<
Tensor
>
pos
,
std
::
shared_ptr
<
Tensor
>
sin
,
std
::
shared_ptr
<
Tensor
>
sin
,
std
::
shared_ptr
<
Tensor
>
cos
)
{
std
::
shared_ptr
<
Tensor
>
cos
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
q
->
tdesc
(),
k
->
tdesc
(),
pos
->
tdesc
(),
sin
->
tdesc
(),
cos
->
tdesc
()
);
size_t
key
=
CacheManager
::
createDescriptorKey
(
q
,
k
,
pos
,
sin
,
cos
);
infiniopRoPEDescriptor_t
desc
;
infiniopRoPEDescriptor_t
desc
;
if
(
!
cache_manager
->
getRoPEDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getRoPEDescriptor
(
key
,
desc
))
{
...
@@ -114,19 +101,14 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
...
@@ -114,19 +101,14 @@ void InferenceContext::rope(std::shared_ptr<Tensor> q,
sin
->
data
(),
cos
->
data
(),
stream
));
sin
->
data
(),
cos
->
data
(),
stream
));
}
}
void
InferenceContext
::
causalSoftmax
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
TensorDesc
>
y_desc_overwrite
,
void
InferenceContext
::
causalSoftmax
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
TensorDesc
>
x_desc_overwrite
)
{
std
::
shared_ptr
<
Tensor
>
x
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
size_t
key
=
CacheManager
::
createDescriptorKey
(
y
,
x
,
nullptr
,
nullptr
,
nullptr
);
y_desc_overwrite
?
y_desc_overwrite
:
y
->
tdesc
(),
x_desc_overwrite
?
x_desc_overwrite
:
x
->
tdesc
(),
nullptr
,
nullptr
,
nullptr
);
infiniopCausalSoftmaxDescriptor_t
desc
;
infiniopCausalSoftmaxDescriptor_t
desc
;
if
(
!
cache_manager
->
getCausalSoftmaxDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getCausalSoftmaxDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateCausalSoftmaxDescriptor
(
RUN_INFINI
(
infiniopCreateCausalSoftmaxDescriptor
(
rsrc
->
handle
,
&
desc
,
rsrc
->
handle
,
&
desc
,
y
->
desc
(),
x
->
desc
()));
y_desc_overwrite
?
y_desc_overwrite
->
desc
()
:
y
->
desc
(),
x_desc_overwrite
?
x_desc_overwrite
->
desc
()
:
x
->
desc
()));
cache_manager
->
putCausalSoftmaxDescriptor
(
key
,
desc
);
cache_manager
->
putCausalSoftmaxDescriptor
(
key
,
desc
);
}
}
...
@@ -139,8 +121,10 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<
...
@@ -139,8 +121,10 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<
y
->
data
(),
x
->
data
(),
stream
));
y
->
data
(),
x
->
data
(),
stream
));
}
}
void
InferenceContext
::
swiglu
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
Tensor
>
up
,
std
::
shared_ptr
<
Tensor
>
gate
)
{
void
InferenceContext
::
swiglu
(
std
::
shared_ptr
<
Tensor
>
out
,
size_t
key
=
CacheManager
::
createDescriptorKey
(
out
->
tdesc
(),
up
->
tdesc
(),
gate
->
tdesc
(),
nullptr
,
nullptr
);
std
::
shared_ptr
<
Tensor
>
up
,
std
::
shared_ptr
<
Tensor
>
gate
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
out
,
up
,
gate
,
nullptr
,
nullptr
);
infiniopSwiGLUDescriptor_t
desc
;
infiniopSwiGLUDescriptor_t
desc
;
if
(
!
cache_manager
->
getSwiGLUDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getSwiGLUDescriptor
(
key
,
desc
))
{
...
@@ -158,20 +142,15 @@ void InferenceContext::swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tenso
...
@@ -158,20 +142,15 @@ void InferenceContext::swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tenso
out
->
data
(),
up
->
data
(),
gate
->
data
(),
stream
));
out
->
data
(),
up
->
data
(),
gate
->
data
(),
stream
));
}
}
void
InferenceContext
::
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
TensorDesc
>
out_desc_overwrite
,
void
InferenceContext
::
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
Tensor
>
prob
,
std
::
shared_ptr
<
TensorDesc
>
prob_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
prob
,
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
)
{
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
size_t
key
=
CacheManager
::
createDescriptorKey
(
out
,
prob
,
nullptr
,
nullptr
,
nullptr
);
out_desc_overwrite
?
out_desc_overwrite
:
out
->
tdesc
(),
prob_desc_overwrite
?
prob_desc_overwrite
:
prob
->
tdesc
(),
nullptr
,
nullptr
,
nullptr
);
infiniopRandomSampleDescriptor_t
desc
;
infiniopRandomSampleDescriptor_t
desc
;
if
(
!
cache_manager
->
getRandomSampleDescriptor
(
key
,
desc
))
{
if
(
!
cache_manager
->
getRandomSampleDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateRandomSampleDescriptor
(
RUN_INFINI
(
infiniopCreateRandomSampleDescriptor
(
rsrc
->
handle
,
&
desc
,
rsrc
->
handle
,
&
desc
,
out
->
desc
(),
prob
->
desc
()));
out_desc_overwrite
?
out_desc_overwrite
->
desc
()
:
out
->
desc
(),
prob_desc_overwrite
?
prob_desc_overwrite
->
desc
()
:
prob
->
desc
()));
cache_manager
->
putRandomSampleDescriptor
(
key
,
desc
);
cache_manager
->
putRandomSampleDescriptor
(
key
,
desc
);
}
}
...
...
src/models/inference_context.hpp
View file @
366386d3
...
@@ -15,25 +15,28 @@ struct InferenceContext {
...
@@ -15,25 +15,28 @@ struct InferenceContext {
InferenceContext
(
DeviceResource
*
rsrc
,
CacheManager
*
cache_manager
,
infinirtStream_t
stream
);
InferenceContext
(
DeviceResource
*
rsrc
,
CacheManager
*
cache_manager
,
infinirtStream_t
stream
);
void
ensure_workspace
(
size_t
required_size
);
void
ensure_workspace
(
size_t
required_size
);
void
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
void
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
w
,
std
::
shared_ptr
<
Tensor
>
w
,
float
epsilon
);
float
epsilon
);
void
gemm
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
TensorDesc
>
c_desc_overwrite
,
void
gemm
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
TensorDesc
>
a_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
,
std
::
shared_ptr
<
TensorDesc
>
b_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
);
float
alpha
,
float
beta
);
void
rearrange
(
std
::
shared_ptr
<
Tensor
>
dst
,
std
::
shared_ptr
<
TensorDesc
>
dst_desc_overwrite
,
void
rearrange
(
std
::
shared_ptr
<
Tensor
>
dst
,
std
::
shared_ptr
<
Tensor
>
src
,
std
::
shared_ptr
<
TensorDesc
>
src_desc_overwrite
);
std
::
shared_ptr
<
Tensor
>
src
);
void
rope
(
std
::
shared_ptr
<
Tensor
>
q
,
void
rope
(
std
::
shared_ptr
<
Tensor
>
q
,
std
::
shared_ptr
<
Tensor
>
k
,
std
::
shared_ptr
<
Tensor
>
k
,
std
::
shared_ptr
<
Tensor
>
pos
,
std
::
shared_ptr
<
Tensor
>
pos
,
std
::
shared_ptr
<
Tensor
>
sin
,
std
::
shared_ptr
<
Tensor
>
sin
,
std
::
shared_ptr
<
Tensor
>
cos
);
std
::
shared_ptr
<
Tensor
>
cos
);
void
causalSoftmax
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
TensorDesc
>
y_desc_overwrite
,
void
causalSoftmax
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
TensorDesc
>
x_desc_overwrite
);
std
::
shared_ptr
<
Tensor
>
x
);
void
swiglu
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
Tensor
>
up
,
std
::
shared_ptr
<
Tensor
>
gate
);
void
swiglu
(
std
::
shared_ptr
<
Tensor
>
out
,
void
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
TensorDesc
>
out_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
up
,
std
::
shared_ptr
<
Tensor
>
prob
,
std
::
shared_ptr
<
TensorDesc
>
prob_desc_overwrite
,
std
::
shared_ptr
<
Tensor
>
gate
);
void
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
Tensor
>
prob
,
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
);
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
);
};
};
src/models/jiuge/jiuge.cpp
View file @
366386d3
...
@@ -166,15 +166,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -166,15 +166,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
}
// Attention
// Attention
auto
qkv_desc
=
TensorDesc
::
create
(
dt_logits
,
qkv_buf
->
shape
(),
qkv_buf
->
strides
());
auto
b_attn_qkv_desc
=
TensorDesc
::
create
(
dt_logits
,
{
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
});
auto
o_desc
=
TensorDesc
::
create
(
dt_logits
,
o_buf
->
shape
(),
o_buf
->
strides
());
qkv_buf
->
dimSplit
(
1
,
{
nh
+
nkvh
*
2
,
dh
});
// (ntok, nh + 2 * nkvh, dh)
qkv_buf
->
dimSplit
(
1
,
{
nh
+
nkvh
*
2
,
dh
});
// (ntok, nh + 2 * nkvh, dh)
// attention inner
// attention inner
size_t
max_qk_size
=
0
;
size_t
max_qk_size
=
0
;
size_t
max_seq_len
=
0
;
size_t
max_seq_len
=
0
;
o_buf
->
dimSplit
(
1
,
{
nh
,
dh
});
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
past_len
=
req_pos
[
req
];
auto
past_len
=
req_pos
[
req
];
...
@@ -193,24 +188,19 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -193,24 +188,19 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
auto
up_buf
=
gate_up_buf
->
slice
(
1
,
di
,
di
);
auto
up_buf
=
gate_up_buf
->
slice
(
1
,
di
,
di
);
// Output and sample
auto
result_desc
=
TensorDesc
::
create
(
INFINI_DTYPE_I64
,
{},
{});
auto
prob_desc
=
TensorDesc
::
create
(
dt_logits
,
{
dvoc
},
{
1
});
// Compute
// Compute
for
(
uint32_t
layer
=
0
;
layer
<
nlayer
;
layer
++
)
{
for
(
uint32_t
layer
=
0
;
layer
<
nlayer
;
layer
++
)
{
// 1. Attention
// 1. Attention
// rms norm
// rms norm
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
// qkv_proj
// qkv_proj
qkv_buf
->
dimMerge
(
1
,
2
);
if
(
has_qkv_bias
)
{
if
(
has_qkv_bias
)
{
ctx
.
rearrange
(
qkv_buf
,
qkv_desc
,
rsrc
.
b_attn_qkv
[
layer
]
,
b_attn_qkv_desc
);
ctx
.
rearrange
(
qkv_buf
,
rsrc
.
b_attn_qkv
[
layer
]
->
reDesc
({
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
})
);
}
}
ctx
.
gemm
(
qkv_buf
,
qkv_desc
,
ctx
.
gemm
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
has_qkv_bias
?
1.0
:
0.0
);
logits_out
,
nullptr
,
rsrc
.
w_attn_qkv
[
layer
],
nullptr
,
1.0
,
has_qkv_bias
?
1.0
:
0.0
);
// rope
// rope
qkv_buf
->
dimSplit
(
1
,
{
nh
+
nkvh
*
2
,
dh
});
ctx
.
rope
(
qkv_buf
->
slice
(
1
,
0
,
nh
),
qkv_buf
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_buf
->
slice
(
1
,
0
,
nh
),
qkv_buf
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_buf
->
slice
(
1
,
nh
,
nkvh
),
qkv_buf
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_buf
->
slice
(
1
,
nh
,
nkvh
),
qkv_buf
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
...
@@ -219,43 +209,41 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -219,43 +209,41 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
past_len
=
req_pos
[
req
];
auto
past_len
=
req_pos
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
total_len
=
past_len
+
seq_len
;
auto
total_len
=
past_len
+
seq_len
;
auto
o
=
o_buf
->
slice
({{
0
,
token_offset
,
seq_len
}});
auto
o
=
o_buf
->
dimSplit
(
1
,
{
nh
,
dh
})
->
slice
({{
0
,
token_offset
,
seq_len
}});
auto
q
=
qkv_buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}});
auto
q
=
qkv_buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}});
auto
k
=
qkv_buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
k
=
qkv_buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
v
=
qkv_buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
auto
v
=
qkv_buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
auto
qt_rearrange_desc
=
TensorDesc
::
create
(
dt_logits
,
{
nkvh
,
ngroup
,
seq_len
,
dh
});
auto
qt_gemm_desc
=
TensorDesc
::
create
(
dt_logits
,
{
nkvh
,
ngroup
*
seq_len
,
dh
});
auto
qk_gemm_desc
=
TensorDesc
::
create
(
dt_logits
,
{
nkvh
,
ngroup
*
seq_len
,
total_len
});
// self attention
// self attention
// concat
// concat
ctx
.
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
nullptr
,
k
,
nullptr
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
nullptr
,
v
,
nullptr
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
// qk
ctx
.
rearrange
(
rearrange_q_buf
,
qt_rearrange_desc
,
ctx
.
rearrange
(
rearrange_q_buf
->
dimSplit
(
1
,
{
ngroup
,
seq_len
})
,
q
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
})
,
nullptr
);
q
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
}));
ctx
.
gemm
(
qk_buf
,
qk_gemm_desc
,
qk_buf
->
dimSplit
(
1
,
{
seq_len
,
total_len
});
rearrange_q_buf
,
qt_gemm_desc
,
qk_buf
->
dimSplit
(
0
,
{
nkvh
,
ngroup
});
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permut
e
(
{
1
,
2
,
0
}),
nullptr
,
qk_buf
->
dimMerg
e
(
1
,
2
);
1.
/
sqrt
(
dh
),
0.0
);
ctx
.
gemm
(
qk_buf
,
rearrange_q_buf
->
dimMerge
(
1
,
2
),
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
}),
1.
/
sqrt
(
dh
),
0.0
);
// softmax
// softmax
auto
qk_desc
=
TensorDesc
::
create
(
dt_logits
,
{
nkvh
*
ngroup
,
seq_len
,
total_len
});
qk_buf
->
dimSplit
(
1
,
{
ngroup
,
seq_len
});
ctx
.
causalSoftmax
(
qk_buf
,
qk_desc
,
qk_buf
,
qk_desc
);
qk_buf
->
dimMerge
(
0
,
1
);
ctx
.
gemm
(
attn_val_buf
,
qt_gemm_desc
,
ctx
.
causalSoftmax
(
qk_buf
,
qk_buf
);
qk_buf
,
qk_gemm_desc
,
qk_buf
->
dimSplit
(
0
,
{
nkvh
,
ngroup
});
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
}),
nullptr
,
qk_buf
->
dimMerge
(
1
,
2
);
1.0
,
0.0
);
ctx
.
gemm
(
attn_val_buf
,
qk_buf
,
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
}),
1.0
,
0.0
);
qk_buf
->
dimSplit
(
1
,
{
ngroup
,
seq_len
});
qk_buf
->
dimMerge
(
2
,
3
);
qk_buf
->
dimMerge
(
0
,
1
);
// rearrange attn val
// rearrange attn val
ctx
.
rearrange
(
o
,
TensorDesc
::
createWithOrder
(
dt_logits
,
{
nkvh
,
ngroup
,
seq_len
,
dh
},
{
1
,
2
,
0
,
3
}),
attn_val_buf
->
dimSplit
(
0
,
{
nkvh
,
ngroup
});
attn_val_buf
,
qt_rearrange_desc
);
ctx
.
rearrange
(
o
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
}),
attn_val_buf
);
attn_val_buf
->
dimMerge
(
0
,
1
);
token_offset
+=
seq_len
;
token_offset
+=
seq_len
;
}
}
// o_proj
// o_proj
ctx
.
gemm
(
logits_in
,
nullptr
,
ctx
.
gemm
(
logits_in
,
o_buf
->
dimMerge
(
1
,
2
),
rsrc
.
w_attn_out
[
layer
],
1.0
,
idev
==
0
?
1.0
:
0.0
);
// only rank 0 adds residual
o_buf
,
o_desc
,
rsrc
.
w_attn_out
[
layer
],
nullptr
,
1.0
,
idev
==
0
?
1.0
:
0.0
);
// only rank 0 adds residual
// All_reduce if distributed
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
if
(
rsrc
.
comm
!=
nullptr
)
{
...
@@ -267,15 +255,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -267,15 +255,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// 2. FFN
// 2. FFN
// rms_norm
// rms_norm
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
gemm
(
gate_up_buf
,
nullptr
,
ctx
.
gemm
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
);
logits_out
,
nullptr
,
rsrc
.
w_ffn_gate_up
[
layer
],
nullptr
,
1.0
,
0.0
);
ctx
.
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
ctx
.
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
ctx
.
gemm
(
logits_in
,
nullptr
,
ctx
.
gemm
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
idev
==
0
?
1.0
:
0.0
);
// only rank 0 adds residual
gate_buf
,
nullptr
,
rsrc
.
w_ffn_down
[
layer
],
nullptr
,
1.0
,
idev
==
0
?
1.0
:
0.0
);
// only rank 0 adds residual
// All_reduce if distributed
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
if
(
rsrc
.
comm
!=
nullptr
)
{
...
@@ -296,18 +278,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -296,18 +278,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc
.
w_out_norm
,
rsrc
.
w_out_norm
,
meta
.
epsilon
);
meta
.
epsilon
);
}
}
ctx
.
gemm
(
prob_buf
,
nullptr
,
ctx
.
gemm
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
);
logits_out
->
slice
(
0
,
0
,
nreq
),
nullptr
,
rsrc
.
w_out_embd
,
nullptr
,
1.0
,
0.0
);
std
::
random_device
_rd
;
std
::
random_device
_rd
;
std
::
mt19937
gen
(
_rd
());
std
::
mt19937
gen
(
_rd
());
token_offset
=
0
;
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
seq_len
=
req_lens
[
req
];
auto
seq_len
=
req_lens
[
req
];
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
ctx
.
randomSample
(
result_buf
->
slice
(
0
,
req
,
1
),
result_desc
,
ctx
.
randomSample
(
result_buf
->
reDesc
({},
{})
,
prob_buf
->
slice
(
0
,
req
,
1
),
prob_desc
,
prob_buf
->
reDesc
({
dvoc
},
{
1
})
,
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
token_offset
+=
seq_len
;
token_offset
+=
seq_len
;
}
}
...
@@ -354,7 +333,7 @@ inferBatch(struct JiugeModel *model,
...
@@ -354,7 +333,7 @@ inferBatch(struct JiugeModel *model,
void
launchDevice
(
const
JiugeMeta
&
meta
,
const
JiugeWeights
*
weights
,
DeviceResource
*
rsrc
,
InferState
&
state
,
InferRequest
&
req
,
void
launchDevice
(
const
JiugeMeta
&
meta
,
const
JiugeWeights
*
weights
,
DeviceResource
*
rsrc
,
InferState
&
state
,
InferRequest
&
req
,
infiniDevice_t
device
,
int
idev
,
int
ndev
,
int
dev_id
,
infinicclComm_t
comm
)
{
infiniDevice_t
device
,
int
idev
,
int
ndev
,
int
dev_id
,
infinicclComm_t
comm
)
{
CacheManager
cache_manager
(
100
);
CacheManager
cache_manager
(
256
);
InferenceContext
ctx
(
rsrc
,
&
cache_manager
,
rsrc
->
stream
);
InferenceContext
ctx
(
rsrc
,
&
cache_manager
,
rsrc
->
stream
);
// Create Device Resource
// Create Device Resource
...
...
src/tensor.hpp
View file @
366386d3
...
@@ -78,6 +78,7 @@ public:
...
@@ -78,6 +78,7 @@ public:
void
dimMerge
(
size_t
dim_start
,
size_t
dim_end
);
void
dimMerge
(
size_t
dim_start
,
size_t
dim_end
);
void
dimSplit
(
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
);
void
dimSplit
(
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
);
void
permute
(
const
std
::
vector
<
size_t
>
&
order
);
void
permute
(
const
std
::
vector
<
size_t
>
&
order
);
void
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
);
};
};
class
Tensor
:
public
std
::
enable_shared_from_this
<
Tensor
>
{
class
Tensor
:
public
std
::
enable_shared_from_this
<
Tensor
>
{
...
@@ -110,6 +111,7 @@ public:
...
@@ -110,6 +111,7 @@ public:
std
::
shared_ptr
<
Tensor
>
dimSplit
(
size_t
dim
,
std
::
shared_ptr
<
Tensor
>
dimSplit
(
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
);
const
std
::
vector
<
size_t
>
&
dims
);
std
::
shared_ptr
<
Tensor
>
permute
(
const
std
::
vector
<
size_t
>
&
order
);
std
::
shared_ptr
<
Tensor
>
permute
(
const
std
::
vector
<
size_t
>
&
order
);
std
::
shared_ptr
<
Tensor
>
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
);
void
*
data
(
ptrdiff_t
offset
=
0
);
void
*
data
(
ptrdiff_t
offset
=
0
);
void
const
*
data
(
ptrdiff_t
offset
=
0
)
const
;
void
const
*
data
(
ptrdiff_t
offset
=
0
)
const
;
void
copyFrom
(
std
::
shared_ptr
<
Tensor
const
>
src
,
infiniopHandle_t
handle
,
void
copyFrom
(
std
::
shared_ptr
<
Tensor
const
>
src
,
infiniopHandle_t
handle
,
...
@@ -120,7 +122,6 @@ public:
...
@@ -120,7 +122,6 @@ public:
infiniDtype_t
dtype
()
const
;
infiniDtype_t
dtype
()
const
;
bool
isContigous
()
const
;
bool
isContigous
()
const
;
infiniopTensorDescriptor_t
desc
()
const
;
infiniopTensorDescriptor_t
desc
()
const
;
std
::
shared_ptr
<
TensorDesc
>
tdesc
()
const
;
ptrdiff_t
dataOffset
()
const
;
ptrdiff_t
dataOffset
()
const
;
infiniDevice_t
deviceType
()
const
;
infiniDevice_t
deviceType
()
const
;
int
deviceId
()
const
;
int
deviceId
()
const
;
...
...
src/tensor/tensor.cpp
View file @
366386d3
...
@@ -108,7 +108,6 @@ ptrdiff_t Tensor::dataOffset() const {
...
@@ -108,7 +108,6 @@ ptrdiff_t Tensor::dataOffset() const {
}
}
infiniopTensorDescriptor_t
Tensor
::
desc
()
const
{
return
_desc
->
desc
();
}
infiniopTensorDescriptor_t
Tensor
::
desc
()
const
{
return
_desc
->
desc
();
}
std
::
shared_ptr
<
TensorDesc
>
Tensor
::
tdesc
()
const
{
return
_desc
;
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
buffer
(
infiniDtype_t
dtype
,
std
::
shared_ptr
<
Tensor
>
Tensor
::
buffer
(
infiniDtype_t
dtype
,
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
vector
<
size_t
>
&
shape
,
...
...
src/tensor/transform.cpp
View file @
366386d3
...
@@ -114,3 +114,14 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
...
@@ -114,3 +114,14 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
this
->
_desc
->
permute
(
order
);
this
->
_desc
->
permute
(
order
);
return
shared_from_this
();
return
shared_from_this
();
}
}
void
TensorDesc
::
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
{
this
->
_shape
=
new_shape
;
this
->
_strides
=
new_strides
;
this
->
resetDesc
();
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
{
this
->
_desc
->
reDesc
(
new_shape
,
new_strides
);
return
shared_from_this
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment