jerrrrry / infinilm · Commits

Commit ee495225, authored Aug 14, 2025 by PanZezhong

    add perplexity test

Parent: 07aa6990

Showing 8 changed files with 326 additions and 161 deletions:
include/infinicore_infer/models/jiuge.h   +16   -1
scripts/jiuge.py                          +57   -1
scripts/jiuge_ppl.py                      +107  -0
scripts/libinfinicore_infer.py            +12   -0
scripts/test_ppl.py                       +62   -0
scripts/test_server.py                    +0    -130
src/models/jiuge/jiuge.cpp                +71   -29
src/models/jiuge/jiuge_impl.hpp           +1    -0
include/infinicore_infer/models/jiuge.h

@@ -75,7 +75,7 @@
 __C __export void
 dropKVCache(const struct JiugeModel *, struct KVCache *);

-/// @brief Run one round of batched inference
+/// @brief Run one round of batched inference and sample new tokens
 /// @param tokens pointer to the input tokens
 /// @param ntok number of input tokens
 /// @param nreq number of requests

@@ -94,4 +94,19 @@ inferBatch(struct JiugeModel *,
            const float *temperature, const uint32_t *topk, const float *topp,
            uint32_t *output);

+/// @brief Run one round of batched inference and output the logits produced after the output embedding
+/// @param tokens pointer to the input tokens
+/// @param ntok number of input tokens
+/// @param nreq number of requests
+/// @param req_lens number of tokens in each request
+/// @param req_pos starting position of each request
+/// @param kv_caches KV cache of each request
+/// @param logits output buffer for the logits, one row per input token (at least ntok x dvoc values)
+__C __export void
+forwardBatch(struct JiugeModel *,
+             const uint32_t *tokens, uint32_t ntok,
+             const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+             struct KVCache **kv_caches,
+             void *logits);
 #endif
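Since forwardBatch only receives a raw void *logits, the caller owns and sizes the output buffer. A minimal sketch of that contract follows; the sizes here are placeholders, while in scripts/jiuge.py the real values come from batch_inputs.ntok, self.meta.dvoc, and self.meta.torch_dtype_logits:

import torch

# Placeholder sizes; the call itself is commented out because it needs a live
# model instance and request arrays.
ntok, dvoc = 7, 32000
logits = torch.zeros((ntok, dvoc), dtype=torch.float16)  # one row per input token
# forward_batch(model_instance, tokens, ntok, req_lens, nreq, req_pos,
#               kv_caches, logits.data_ptr())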
scripts/jiuge.py

-from typing import List
+from typing import List, Sequence
 from libinfinicore_infer import (
     JiugeMetaCStruct,
     JiugeWeightsCStruct,

@@ -10,6 +12,7 @@ from libinfinicore_infer import (
     create_kv_cache,
     drop_kv_cache,
     infer_batch,
+    forward_batch,
 )
 from infer_task import InferTask, KVCache

@@ -582,6 +585,59 @@ class JiugeForCauslLM:
             infer_task._kv_cache.drop(self)

         return output_content, avg_time

+    def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
+        tasks = [
+            InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id)
+            for i in range(batch_size)
+        ]
+        kv_caches = [KVCache(self) for _ in range(batch_size)]
+
+        nll = 0.0
+        total_len = 0
+
+        for i in range(0, len(test_sequences), batch_size):
+            batch_id = 0
+            true_tokens = []
+            while batch_id < batch_size and batch_id + i < len(test_sequences):
+                input_tokens = test_sequences[i + batch_id][:-1]
+                true_tokens.extend(test_sequences[i + batch_id][1:])
+                tasks[batch_id].tokens = input_tokens
+                tasks[batch_id].bind_kvcache(kv_caches[batch_id])
+                batch_id += 1
+
+            batch_inputs = JiugeBatchedTask(tasks[:batch_id])
+            logits = torch.zeros(
+                (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
+            )
+            forward_batch(
+                self.model_instance,
+                batch_inputs.tokens,
+                batch_inputs.ntok,
+                batch_inputs.req_lens,
+                batch_inputs.nreq,
+                batch_inputs.req_pos,
+                batch_inputs.kv_caches,
+                logits.data_ptr(),
+            )
+
+            logits = logits.float()
+            token_ids = torch.tensor(true_tokens, dtype=torch.int64)  # [ntok,]
+            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (ntok, vocab)
+            token_logprobs = log_probs[torch.arange(batch_inputs.ntok), token_ids]  # (ntok,)
+
+            start = 0
+            for l in batch_inputs.req_lens_list:
+                nll += -token_logprobs[start : start + l].sum().item()
+                start += l
+            total_len += token_logprobs.numel()
+
+        for task in tasks:
+            task.release_kvcache()
+
+        return math.exp(nll / total_len)
+
     def destroy_model_instance(self):
         destroy_jiuge_model(self.model_instance)
         print("Model destroyed")
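For reference, the quantity perplexity() returns is the standard token-level perplexity over the shifted targets, where N is total_len and the summed negative log-probabilities are the accumulated nll:

\mathrm{PPL} = \exp\!\Big(-\frac{1}{N}\sum_{t=1}^{N} \log p\,(x_t \mid x_{<t})\Big)

Each sequence contributes its tokens shifted by one: positions [:-1] are fed as input and positions [1:] are scored, so model.perplexity(input_ids_list) over a list of equal-length token chunks yields a single scalar directly comparable to the torch reference in jiuge_ppl.py below.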
scripts/jiuge_ppl.py (new file, mode 100644)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from jiuge import JiugeForCauslLM
from libinfinicore_infer import DeviceType

DEVICE_TYPE_MAP = {
    "cpu": DeviceType.DEVICE_TYPE_CPU,
    "nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
    "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
    "ascend": DeviceType.DEVICE_TYPE_ASCEND,
    "metax": DeviceType.DEVICE_TYPE_METAX,
    "moore": DeviceType.DEVICE_TYPE_MOORE,
}

TORCH_DEVICE_TYPE_MAP = {
    "cpu": "cpu",
    "nvidia": "cuda",
    "cambricon": "mlu",
    "ascend": "npu",
    "metax": "cuda",
    "moore": "cuda",
}


def test_torch(input_ids_list, device_):
    device = TORCH_DEVICE_TYPE_MAP[device_]
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    ).to(device)
    model.eval()

    total_neg_log_likelihood = 0
    total_tokens = 0

    with torch.no_grad():
        for input_ids in input_ids_list:
            input_ids = torch.tensor(input_ids, device=device)
            # shift inputs and labels
            inputs = input_ids[:-1].unsqueeze(0)  # [1, seq_len-1]
            labels = input_ids[1:].unsqueeze(0)  # [1, seq_len-1]

            outputs = model(inputs, use_cache=False)
            logits = outputs.logits  # [1, seq_len-1, vocab_size]
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            # gather log probs of true tokens
            true_token_log_probs = log_probs.gather(
                dim=-1, index=labels.unsqueeze(-1)
            ).squeeze(-1)

            total_neg_log_likelihood += -true_token_log_probs.sum().item()
            total_tokens += labels.numel()

    perplexity = torch.exp(torch.tensor(total_neg_log_likelihood / total_tokens))
    return perplexity


def test_infinicore(input_ids_list, device_, ndev_):
    device = DEVICE_TYPE_MAP[device_]
    model = JiugeForCauslLM(
        model_path, device, max_tokens=len(input_ids_list[0]), ndev=ndev_
    )
    perplexity = model.perplexity(input_ids_list)
    model.destroy_model_instance()
    return perplexity


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--dev", type=str, default="cpu", choices=DEVICE_TYPE_MAP.keys())
    parser.add_argument(
        "--ndev",
        type=int,
        default=1,
        help="Number of devices to use (default: 1)",
    )
    args = parser.parse_args()

    seq_len = 512
    model_path = args.model_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    texts = dataset["text"]
    texts = [t for t in texts if len(t.strip()) > 0]
    input_ids_list = []
    for text in texts:
        ids = tokenizer.encode(text)
        # split long sequences into chunks
        for i in range(0, len(ids) - seq_len + 1, seq_len):
            input_ids_list.append(ids[i : i + seq_len])

    perplexity = test_infinicore(input_ids_list, args.dev, args.ndev)
    print(f"InfiniCore Perplexity: {perplexity:.2f}")
    if args.ndev == 1:
        # Todo: support multi-device testing with torch
        perplexity = test_torch(input_ids_list, args.dev)
        print(f"Torch Perplexity: {perplexity.item():.2f}")
scripts/libinfinicore_infer.py

@@ -112,6 +112,17 @@ def __open_library__():
         POINTER(c_float),  # float topp
         POINTER(c_uint),  # unsigned int *output
     ]
+    lib.forwardBatch.restype = None
+    lib.forwardBatch.argtypes = [
+        POINTER(JiugeModelCSruct),  # struct JiugeModel const *
+        POINTER(c_uint),  # unsigned int const *tokens
+        c_uint,  # unsigned int ntok
+        POINTER(c_uint),  # unsigned int const *req_lens
+        c_uint,  # unsigned int nreq
+        POINTER(c_uint),  # unsigned int const *req_pos
+        POINTER(POINTER(KVCacheCStruct)),  # struct KVCache **kv_caches
+        c_void_p,  # void *logits
+    ]
     return lib

@@ -123,3 +134,4 @@ destroy_jiuge_model = LIB.destroyJiugeModel
 create_kv_cache = LIB.createKVCache
 drop_kv_cache = LIB.dropKVCache
 infer_batch = LIB.inferBatch
+forward_batch = LIB.forwardBatch
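For callers outside scripts/jiuge.py, the declared argtypes are satisfied by ordinary ctypes arrays; a hypothetical sketch (the token ids here are invented):

from ctypes import c_uint

# A c_uint array decays to POINTER(c_uint), matching the argtypes above.
token_list = [1, 15043, 3186]  # made-up token ids
tokens = (c_uint * len(token_list))(*token_list)
ntok = c_uint(len(token_list))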
scripts/test_ppl.py (new file, mode 100644)

import math
import requests
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--endpoint", type=str, default="/completions")
    parser.add_argument("--chunk", type=int, default=512)
    args = parser.parse_args()

    API_URL = "http://localhost:" + str(args.port) + args.endpoint
    CHUNK_SIZE = args.chunk

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    # Local tokenizer used for chunking
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    total_neg_log_likelihood = 0.0
    total_tokens = 0

    for example in tqdm(dataset, desc="Evaluating PPL"):
        text = example["text"].strip()
        if not text:
            continue

        # encode, chunk and decode
        tokens = tokenizer.encode(text, add_special_tokens=False)
        for i in range(0, len(tokens), CHUNK_SIZE):
            chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))]
            chunk_text = tokenizer.decode(chunk_tokens)

            resp = requests.post(
                API_URL,
                headers={"Content-Type": "application/json"},
                json={
                    "model": "",
                    "prompt": chunk_text,
                    "max_tokens": 0,
                    "temperature": 1.0,
                    "echo": True,
                    "logprobs": 0,
                },
            ).json()

            logprobs = resp["choices"][0]["logprobs"]["token_logprobs"]
            # skip first token's None
            valid_logprobs = [lp for lp in logprobs[1:] if lp is not None]
            total_neg_log_likelihood += -sum(valid_logprobs)
            total_tokens += len(valid_logprobs)

    # ==== Compute final PPL ====
    ppl = math.exp(total_neg_log_likelihood / total_tokens)
    print(f"Perplexity: {ppl:.4f}")
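The server is expected to return OpenAI-style completion logprobs when echo is enabled; the first entry is None because the first token has no conditioning context. A toy illustration of the aggregation, with made-up logprob values:

import math

token_logprobs = [None, -2.31, -0.87, -1.45]  # hypothetical response values
valid = [lp for lp in token_logprobs[1:] if lp is not None]
ppl = math.exp(-sum(valid) / len(valid))  # perplexity of this one chunk
print(f"{ppl:.4f}")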
scripts/test_server.py (deleted, mode 100644 → 0)

import requests
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

API_URL = "http://localhost:8000/chat/completions"
MODEL = "FM9G-7B"
PROMPT = ["山东最高的山是?", "给我讲个故事"]
CONCURRENCY = 10  # number of concurrent users


def single_run(user_id):
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": PROMPT[user_id % len(PROMPT)]}],
        "max_tokens": 512,
        "stream": True,
    }
    headers = {"Content-Type": "application/json", "Accept": "application/json"}

    print(f"[User {user_id}] Sending request...")
    start = time.perf_counter()
    resp = requests.post(API_URL, headers=headers, json=payload, stream=True)
    resp.raise_for_status()
    ttfb = resp.elapsed.total_seconds()  # time until the HTTP headers arrived
    header_received = time.perf_counter()
    if resp.encoding is None:
        resp.encoding = "utf-8"

    tokens = 0
    chunks = []
    for line in resp.iter_lines(decode_unicode=True):
        if not line or line.strip() == "[DONE]":
            continue
        s = line.strip()
        if s.startswith("data:"):
            s = s[len("data:"):].strip()
        try:
            data = json.loads(s)
        except json.JSONDecodeError:
            continue
        text = data.get("choices", [{}])[0].get("delta", {}).get("content")
        if text:
            chunks.append(text)
            tokens += 1

    stream_done = time.perf_counter()
    # timing calculations
    stream_time = stream_done - header_received
    total_time = stream_done - start
    time_per_token_ms = (stream_time / tokens * 1000) if tokens else float("inf")
    tps = tokens / stream_time if stream_time > 0 else 0

    return {
        "user": user_id,
        "ttfb": ttfb,
        "stream_time": stream_time,
        "total_time": total_time,
        "tokens": tokens,
        "time_per_token_ms": time_per_token_ms,
        "tps": tps,
        "chunks": chunks,
    }


def main():
    worst = None
    best = None
    worst_stream = -1.0
    best_stream = float("inf")
    results = []

    with ThreadPoolExecutor(max_workers=CONCURRENCY) as e:
        futures = [e.submit(single_run, uid) for uid in range(CONCURRENCY)]
        for future in as_completed(futures):
            r = future.result()
            results.append(r)
            print(
                f"User {r['user']} → TTFB = {r['ttfb']:.3f}s, latency = {r['stream_time']:.3f}s, "
                f"tokens = {r['tokens']}, time/token = {r['time_per_token_ms']:.2f}ms, "
                f"TPS = {r['tps']:.1f} tok/s"
            )
            if r["stream_time"] > worst_stream:
                worst_stream = r["stream_time"]
                worst = r
            if r["stream_time"] < best_stream:
                best_stream = r["stream_time"]
                best = r

    # Sort results by user ID
    results.sort(key=lambda x: x["user"])
    with open("responses.txt", "w", encoding="utf-8") as fw:
        for r in results:
            fw.write(f"[User {r['user']}]\n")
            text = "".join(r["chunks"])
            # fixed = text.encode('latin-1').decode('utf-8')
            fixed = text
            fw.write(fixed)
            fw.write("\n\n")

    n = CONCURRENCY
    avg_ttfb = sum(r["ttfb"] for r in results) / n
    avg_token = sum(r["tokens"] for r in results) / n
    avg_stream = sum(r["stream_time"] for r in results) / n
    avg_tps = sum(r["tps"] for r in results) / n
    avg_time_per_token = sum(r["time_per_token_ms"] for r in results) / n

    print(f"\n✅ All {n} requests completed.")
    print(
        f"Averages → TTFB = {avg_ttfb:.3f}s, latency = {avg_stream:.3f}s, "
        f"tokens = {avg_token:.1f}, TPS = {avg_tps:.1f} tok/s, time/token = {avg_time_per_token:.2f}ms"
    )
    if best:
        print("\nFastest user:")
        print(
            f"User {best['user']} → latency = {best['stream_time']:.3f}s, "
            f"tokens = {best['tokens']}, TPS = {best['tps']:.1f} tok/s, "
            f"time/token = {best['time_per_token_ms']:.2f}ms"
        )
    if worst:
        print("\nSlowest user:")
        print(
            f"User {worst['user']} → latency = {worst['stream_time']:.3f}s, "
            f"tokens = {worst['tokens']}, TPS = {worst['tps']:.1f} tok/s, "
            f"time/token = {worst['time_per_token_ms']:.2f}ms"
        )


if __name__ == "__main__":
    main()
src/models/jiuge/jiuge.cpp

@@ -117,7 +117,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
                       const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                       struct KVCache **kv_caches,
                       const float *temperature, const uint32_t *topk, const float *topp,
-                      uint32_t *output) {
+                      uint32_t *output, void *last_logits) {
     auto nlayer = meta.nlayer;
     auto nkvh = meta.nkvh / ndev;
     auto nh = meta.nh / ndev;

@@ -220,12 +220,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
         rearrange(q_rearrange, q);
         auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
         auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
-        linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
+        linear(qk_gemm, rearrange_q_buf, k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
         // softmax
         auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
         causalSoftmax(qk_softmax, qk_softmax);
         auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
-        linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
+        linear(attn_val_buf, qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
         // rearrange attn val
         rearrange(o, attn_val_gemm);

@@ -258,32 +258,41 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     }
     // Sample and Output
     if (idev == 0) {
-        size_t token_offset = 0;
-        for (uint32_t req = 0; req < nreq; req++) {
-            auto seq_len = req_lens[req];
-            token_offset += seq_len;
-            rmsnorm(logits_out->slice(0, req, 1),
-                    logits_in->slice(0, token_offset - 1, 1),
-                    rsrc.w_out_norm, meta.epsilon);
-        }
-        linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
-        std::random_device _rd;
-        std::mt19937 gen(_rd());
-        token_offset = 0;
-        for (uint32_t req = 0; req < nreq; req++) {
-            auto seq_len = req_lens[req];
-            float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
-            randomSample(result_buf->memShare({}, result_buf->dtype()),
-                         prob_buf->view_as({dvoc}, {1}),
-                         random_val, topp[req], topk[req], temperature[req]);
-            token_offset += seq_len;
-        }
-        RUN_INFINI(infinirtStreamSynchronize(stream));
-        RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
-                                  sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
-        for (uint32_t req = 0; req < nreq; req++) {
-            output[req] = result_cpu[req];
-        }
+        if (last_logits != nullptr) {
+            rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon);
+            auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
+            linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+            RUN_INFINI(infinirtStreamSynchronize(stream));
+            RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(),
+                                      dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
+        }
+        if (output != nullptr) {
+            size_t token_offset = 0;
+            for (uint32_t req = 0; req < nreq; req++) {
+                auto seq_len = req_lens[req];
+                token_offset += seq_len;
+                rmsnorm(logits_out->slice(0, req, 1),
+                        logits_in->slice(0, token_offset - 1, 1),
+                        rsrc.w_out_norm, meta.epsilon);
+            }
+            linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+            std::random_device _rd;
+            std::mt19937 gen(_rd());
+            token_offset = 0;
+            for (uint32_t req = 0; req < nreq; req++) {
+                auto seq_len = req_lens[req];
+                float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
+                randomSample(result_buf->memShare({}, result_buf->dtype()),
+                             prob_buf->view_as({dvoc}, {1}),
+                             random_val, topp[req], topk[req], temperature[req]);
+                token_offset += seq_len;
+            }
+            RUN_INFINI(infinirtStreamSynchronize(stream));
+            RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
+                                      sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
+            for (uint32_t req = 0; req < nreq; req++) {
+                output[req] = uint32_t(result_cpu[req]);
+            }
+        }
     }

@@ -302,6 +311,7 @@ inferBatch(struct JiugeModel *model,
     model->req.req_pos = req_pos;
     model->req.kv_caches = kv_caches;
     model->req.output = output;
+    model->req.logits = nullptr;
     model->req.temperature = temperature;
     model->req.topk = topk;
     model->req.topp = topp;

@@ -320,6 +330,38 @@
     }
 }

+__C void
+forwardBatch(struct JiugeModel *model,
+             const uint32_t *tokens, uint32_t ntok,
+             const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+             struct KVCache **kv_caches,
+             void *logits) {
+    model->req.tokens = tokens;
+    model->req.ntok = ntok;
+    model->req.req_lens = req_lens;
+    model->req.nreq = nreq;
+    model->req.req_pos = req_pos;
+    model->req.kv_caches = kv_caches;
+    model->req.output = nullptr;
+    model->req.logits = logits;
+    model->req.temperature = nullptr;
+    model->req.topk = nullptr;
+    model->req.topp = nullptr;
+
+    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
+        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
+        model->states[idev].proceed = true;
+        lock.unlock();
+        model->states[idev].cv_start.notify_one();
+    }
+    for (size_t i = model->dev_ids.size(); i > 0; i--) {
+        auto idev = i - 1;
+        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
+        model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
+        lock.unlock();
+    }
+}
+
 void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc,
                   InferState &state, InferRequest &req, infiniDevice_t device,
                   int idev, int ndev, int dev_id, infinicclComm_t comm) {
     CacheManager cache_manager(100);

@@ -348,7 +390,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
             inferDeviceBatch(meta, *rsrc, idev, ndev,
                              req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos,
                              req.kv_caches,
-                             req.temperature, req.topk, req.topp, req.output);
+                             req.temperature, req.topk, req.topp, req.output, req.logits);
             state.proceed = false;
             lock.unlock();
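forwardBatch reuses the same per-device wake/wait handshake as inferBatch: set proceed under the lock, notify cv_start, then wait on cv_done until each worker clears proceed again (as launchDevice does after running the batch). A minimal sketch of the host side of this pattern in Python threading terms; the class and function names here are hypothetical, not the project's API:

import threading

class DeviceState:
    """Mirrors the per-device mutex / cv_start / cv_done / proceed quartet."""
    def __init__(self):
        self.mtx = threading.Lock()
        self.cv_start = threading.Condition(self.mtx)
        self.cv_done = threading.Condition(self.mtx)
        self.proceed = False

def kick_and_wait(states):
    for st in states:  # wake every device worker
        with st.mtx:
            st.proceed = True
            st.cv_start.notify()
    for st in reversed(states):  # then wait for each worker to finish
        with st.cv_done:
            st.cv_done.wait_for(lambda st=st: not st.proceed)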
src/models/jiuge/jiuge_impl.hpp

@@ -49,6 +49,7 @@ struct InferRequest {
     const uint32_t *topk;
     const float *topp;
     uint32_t *output;
+    void *logits;
 };

 struct JiugeModel {