Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
115badb9
Commit
115badb9
authored
Jul 29, 2025
by
wooway777
Browse files
issue/21 - Made InferenceContext thread-local to allow cleaner operator calls.
parent
2a2ddc57
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
29 deletions
+89
-29
src/models/inference_context.hpp
src/models/inference_context.hpp
+58
-1
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+31
-28
No files found.
src/models/inference_context.hpp
View file @
115badb9
// inference_context.hpp
#pragma once
#include "cache_manager.hpp"
#include "jiuge/jiuge_impl.hpp"
#include "jiuge/jiuge_weight.hpp"
#include <cassert>
struct
InferenceContext
{
DeviceResource
*
rsrc
;
...
...
@@ -49,3 +49,60 @@ struct InferenceContext {
float
alpha
,
float
beta
,
std
::
shared_ptr
<
Tensor
>
residual
);
};
// Thread-local pointer to the InferenceContext currently bound to this thread.
//
// NOTE(fix): this lives in a header. An unnamed namespace here would give
// every translation unit its OWN internal-linkage copy of the pointer, so a
// context set via setInferenceContext() in one .cpp file would be invisible
// to the inline operator wrappers instantiated in another. A C++17 `inline`
// variable has external linkage with a single definition per thread across
// the whole program, which is the intended semantics.
inline thread_local InferenceContext *tls_inference_context = nullptr;

/// @brief Returns the InferenceContext bound to the calling thread.
/// @return Reference to the thread's active context.
/// Asserts (debug builds only) that setInferenceContext() was called first
/// on this thread with a non-null pointer.
inline InferenceContext &getInferenceContext() {
    assert(tls_inference_context != nullptr &&
           "InferenceContext not set for this thread");
    return *tls_inference_context;
}

/// @brief Binds @p ctx to the calling thread; pass nullptr to clear.
/// The caller retains ownership — the pointed-to context must outlive
/// every use of the wrapper functions on this thread.
inline void setInferenceContext(InferenceContext *ctx) {
    tls_inference_context = ctx;
}
/// @brief Element-wise add on the thread-local context: c = a + b.
/// shared_ptr args are taken by const& — the wrapper adds no owner, so this
/// avoids one atomic refcount inc/dec per argument per call.
inline void add(const std::shared_ptr<Tensor> &c,
                const std::shared_ptr<Tensor> &a,
                const std::shared_ptr<Tensor> &b) {
    getInferenceContext().add(c, a, b);
}
/// @brief RMS normalization on the thread-local context: y = rmsnorm(x) * w.
/// @param epsilon Small constant for numerical stability.
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void rmsnorm(const std::shared_ptr<Tensor> &y,
                    const std::shared_ptr<Tensor> &x,
                    const std::shared_ptr<Tensor> &w,
                    float epsilon) {
    getInferenceContext().rmsnorm(y, x, w, epsilon);
}
/// @brief General matrix multiply on the thread-local context:
/// c = alpha * (a @ b) + beta * c.
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void gemm(const std::shared_ptr<Tensor> &c,
                 const std::shared_ptr<Tensor> &a,
                 const std::shared_ptr<Tensor> &b,
                 float alpha, float beta) {
    getInferenceContext().gemm(c, a, b, alpha, beta);
}
/// @brief Copies/reshapes src into dst's layout on the thread-local context.
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void rearrange(const std::shared_ptr<Tensor> &dst,
                      const std::shared_ptr<Tensor> &src) {
    getInferenceContext().rearrange(dst, src);
}
/// @brief Rotary position embedding on the thread-local context, applied
/// in place to q and k using per-token positions and precomputed tables.
/// @param pos Position indices for each token.
/// @param sin Precomputed sine table (parameter shadows ::sin — intentional,
///            name kept for consistency with the member signature).
/// @param cos Precomputed cosine table.
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void rope(const std::shared_ptr<Tensor> &q,
                 const std::shared_ptr<Tensor> &k,
                 const std::shared_ptr<Tensor> &pos,
                 const std::shared_ptr<Tensor> &sin,
                 const std::shared_ptr<Tensor> &cos) {
    getInferenceContext().rope(q, k, pos, sin, cos);
}
/// @brief Causally-masked softmax on the thread-local context (y may alias x
/// — callers in this codebase pass the same tensor for both).
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void causalSoftmax(const std::shared_ptr<Tensor> &y,
                          const std::shared_ptr<Tensor> &x) {
    getInferenceContext().causalSoftmax(y, x);
}
/// @brief SwiGLU activation on the thread-local context: out = up * silu(gate).
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void swiglu(const std::shared_ptr<Tensor> &out,
                   const std::shared_ptr<Tensor> &up,
                   const std::shared_ptr<Tensor> &gate) {
    getInferenceContext().swiglu(out, up, gate);
}
/// @brief Samples one token id from prob into out on the thread-local context.
/// @param random_val Uniform random value in [0, 1) driving the sample.
/// @param top_p Nucleus-sampling cumulative-probability cutoff.
/// @param top_k Keep only the top_k most probable tokens.
/// @param temperature Logit temperature (higher = more uniform).
/// shared_ptr args by const& to avoid per-call atomic refcount traffic.
inline void randomSample(const std::shared_ptr<Tensor> &out,
                         const std::shared_ptr<Tensor> &prob,
                         float random_val, float top_p,
                         uint32_t top_k, float temperature) {
    getInferenceContext().randomSample(out, prob, random_val,
                                       top_p, top_k, temperature);
}
/// @brief Linear layer on the thread-local context:
/// c = alpha * (a @ b) + beta * c, plus residual when non-null.
/// @param residual Optional tensor added to the result; pass nullptr to skip
///                 (callers use this for bias and residual connections).
/// shared_ptr args by const& to avoid per-call atomic refcount traffic;
/// a const& binds nullptr literals the same way the by-value form did.
inline void linear(const std::shared_ptr<Tensor> &c,
                   const std::shared_ptr<Tensor> &a,
                   const std::shared_ptr<Tensor> &b,
                   float alpha, float beta,
                   const std::shared_ptr<Tensor> &residual) {
    getInferenceContext().linear(c, a, b, alpha, beta, residual);
}
src/models/jiuge/jiuge.cpp
View file @
115badb9
...
...
@@ -117,7 +117,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
struct
KVCache
**
kv_caches
,
const
float
*
temperature
,
const
uint32_t
*
topk
,
const
float
*
topp
,
uint32_t
*
output
,
InferenceContext
&
ctx
)
{
uint32_t
*
output
)
{
auto
nlayer
=
meta
.
nlayer
;
auto
nkvh
=
meta
.
nkvh
/
ndev
;
auto
nh
=
meta
.
nh
/
ndev
;
...
...
@@ -191,16 +191,16 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for
(
uint32_t
layer
=
0
;
layer
<
nlayer
;
layer
++
)
{
// 1. Attention
// rms norm
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
// qkv_proj
if
(
has_qkv_bias
)
{
ctx
.
rearrange
(
qkv_buf
,
rsrc
.
b_attn_qkv
[
layer
]
->
view
({
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
}));
rearrange
(
qkv_buf
,
rsrc
.
b_attn_qkv
[
layer
]
->
view
({
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
}));
}
ctx
.
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
has_qkv_bias
?
qkv_buf
:
nullptr
);
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
has_qkv_bias
?
qkv_buf
:
nullptr
);
// rope
auto
qkv_rope
=
qkv_buf
->
viewReshaped
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
ctx
.
rope
(
qkv_rope
->
slice
(
1
,
0
,
nh
),
qkv_rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
0
,
nh
),
qkv_rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
size_t
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
...
...
@@ -214,28 +214,28 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// self attention
// concat
ctx
.
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
auto
q_rearrange
=
rearrange_q_buf
->
viewReshaped
({
nkvh
,
ngroup
,
seq_len
,
dh
});
ctx
.
rearrange
(
q_rearrange
,
q
);
rearrange
(
q_rearrange
,
q
);
auto
qk_gemm
=
qk_buf
->
viewReshaped
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
ctx
.
linear
(
qk_gemm
,
rearrange_q_buf
,
k_gemm
,
1.
/
sqrt
(
dh
),
0.0
,
nullptr
);
linear
(
qk_gemm
,
rearrange_q_buf
,
k_gemm
,
1.
/
sqrt
(
dh
),
0.0
,
nullptr
);
// softmax
auto
qk_softmax
=
qk_buf
->
viewReshaped
({
nh
,
seq_len
,
total_len
});
ctx
.
causalSoftmax
(
qk_softmax
,
qk_softmax
);
causalSoftmax
(
qk_softmax
,
qk_softmax
);
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
ctx
.
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
);
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
);
// rearrange attn val
auto
attn_val_gemm
=
attn_val_buf
->
viewReshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
ctx
.
rearrange
(
o
,
attn_val_gemm
);
rearrange
(
o
,
attn_val_gemm
);
token_offset
+=
seq_len
;
}
// o_proj
ctx
.
linear
(
logits_in
,
o_buf
,
rsrc
.
w_attn_out
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
linear
(
logits_in
,
o_buf
,
rsrc
.
w_attn_out
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
...
...
@@ -245,10 +245,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
}
// 2. FFN
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
linear
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
,
nullptr
);
ctx
.
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
ctx
.
linear
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
linear
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
,
nullptr
);
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
linear
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
...
...
@@ -264,21 +264,21 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
seq_len
=
req_lens
[
req
];
token_offset
+=
seq_len
;
ctx
.
rmsnorm
(
logits_out
->
slice
(
0
,
req
,
1
),
logits_in
->
slice
(
0
,
token_offset
-
1
,
1
),
rsrc
.
w_out_norm
,
meta
.
epsilon
);
rmsnorm
(
logits_out
->
slice
(
0
,
req
,
1
),
logits_in
->
slice
(
0
,
token_offset
-
1
,
1
),
rsrc
.
w_out_norm
,
meta
.
epsilon
);
}
ctx
.
linear
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
,
nullptr
);
linear
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
,
nullptr
);
std
::
random_device
_rd
;
std
::
mt19937
gen
(
_rd
());
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
seq_len
=
req_lens
[
req
];
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
ctx
.
randomSample
(
result_buf
->
view
({},
{}),
prob_buf
->
view
({
dvoc
},
{
1
}),
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
randomSample
(
result_buf
->
view
({},
{}),
prob_buf
->
view
({
dvoc
},
{
1
}),
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
token_offset
+=
seq_len
;
}
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
...
...
@@ -327,6 +327,9 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
CacheManager
cache_manager
(
256
);
InferenceContext
ctx
(
rsrc
,
&
cache_manager
,
rsrc
->
stream
);
// Set the inference context for this thread
setInferenceContext
(
&
ctx
);
// Create Device Resource
createDeviceResource
(
rsrc
,
&
meta
,
weights
,
device
,
idev
,
ndev
,
dev_id
,
comm
);
{
...
...
@@ -347,8 +350,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
inferDeviceBatch
(
meta
,
*
rsrc
,
idev
,
ndev
,
req
.
tokens
,
req
.
ntok
,
req
.
req_lens
,
req
.
nreq
,
req
.
req_pos
,
req
.
kv_caches
,
req
.
temperature
,
req
.
topk
,
req
.
topp
,
req
.
output
,
ctx
);
req
.
temperature
,
req
.
topk
,
req
.
topp
,
req
.
output
);
state
.
proceed
=
false
;
lock
.
unlock
();
...
...
@@ -357,6 +359,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
// Clean-Up
releaseDeviceResource
(
*
rsrc
);
setInferenceContext
(
nullptr
);
// Clear the context when done
}
JiugeModel
::
JiugeModel
(
const
JiugeMeta
*
_meta
,
const
JiugeWeights
*
weights
,
infiniDevice_t
device_
,
std
::
vector
<
int
>
device_ids
)
:
meta
(
*
_meta
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment