Commit ee495225 authored by PanZezhong

add perplexity test

parent 07aa6990
@@ -75,7 +75,7 @@ __C __export void
dropKVCache(const struct JiugeModel *,
struct KVCache *);
-/// @brief Run one round of batched inference
+/// @brief Run one round of batched inference and sample a new token for each request
/// @param tokens Pointer to the input token array
/// @param ntok Total number of input tokens
/// @param nreq Number of requests
@@ -94,4 +94,19 @@ inferBatch(struct JiugeModel *,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
/// @brief Run one round of batched inference and output the logits produced after the output embedding
/// @param tokens Pointer to the input token array
/// @param ntok Total number of input tokens
/// @param nreq Number of requests
/// @param req_lens Number of tokens in each request
/// @param req_pos Starting position of each request
/// @param kv_caches KV cache of each request
/// @param logits Output logits buffer, one row of dvoc values per input token
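/// @note The caller must preallocate ntok * dvoc elements of the logits data type;
/// logits are written for every input token, not just the last token of each request.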
__C __export void
forwardBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits);
#endif
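A minimal calling sketch for the new entry point through the Python binding added in this commit (a sketch only: ntok, dvoc, dtype_logits and the batch arguments are assumed to come from an already-prepared batch):

    import torch
    from libinfinicore_infer import forward_batch

    logits = torch.zeros((ntok, dvoc), dtype=dtype_logits)  # one row per input token
    forward_batch(model_instance, tokens, ntok, req_lens, nreq, req_pos,
                  kv_caches, logits.data_ptr())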
-from typing import List
+from typing import List, Sequence
from libinfinicore_infer import (
JiugeMetaCStruct,
JiugeWeightsCStruct,
@@ -10,6 +12,7 @@ from libinfinicore_infer import (
create_kv_cache,
drop_kv_cache,
infer_batch,
forward_batch,
)
from infer_task import InferTask, KVCache
@@ -582,6 +585,59 @@ class JiugeForCauslLM:
infer_task._kv_cache.drop(self)
return output_content, avg_time
def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
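        """Compute corpus perplexity over pre-tokenized sequences.

        Each sequence is scored with teacher forcing: tokens[:-1] are fed
        through forwardBatch and tokens[1:] serve as labels. Returns
        exp(total NLL / total number of predicted tokens).
        """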
tasks = [
InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id)
for i in range(batch_size)
]
kv_caches = [KVCache(self) for _ in range(batch_size)]
nll = 0.0
total_len = 0
for i in range(0, len(test_sequences), batch_size):
batch_id = 0
true_tokens = []
while batch_id < batch_size and batch_id + i < len(test_sequences):
input_tokens = test_sequences[i + batch_id][:-1]
true_tokens.extend(test_sequences[i + batch_id][1:])
tasks[batch_id].tokens = input_tokens
tasks[batch_id].bind_kvcache(kv_caches[batch_id])
batch_id += 1
batch_inputs = JiugeBatchedTask(tasks[:batch_id])
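            # Preallocate the host buffer that forwardBatch fills with one row
            # of vocabulary logits per input token.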
logits = torch.zeros(
(batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
)
forward_batch(
self.model_instance,
batch_inputs.tokens,
batch_inputs.ntok,
batch_inputs.req_lens,
batch_inputs.nreq,
batch_inputs.req_pos,
batch_inputs.kv_caches,
logits.data_ptr(),
)
logits = logits.float()
token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,]
log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab)
token_logprobs = log_probs[
torch.arange(batch_inputs.ntok), token_ids
] # (ntok,)
start = 0
for l in batch_inputs.req_lens_list:
nll += -token_logprobs[start : start + l].sum().item()
start += l
total_len += token_logprobs.numel()
for task in tasks:
task.release_kvcache()
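        # Perplexity is the exponential of the mean NLL: e.g. 100 predicted
        # tokens with a total NLL of 230.0 give exp(2.3) ≈ 9.97.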
return math.exp(nll / total_len)
def destroy_model_instance(self):
destroy_jiuge_model(self.model_instance)
print("Model destroyed")
......
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from jiuge import JiugeForCauslLM
from libinfinicore_infer import DeviceType
DEVICE_TYPE_MAP = {
"cpu": DeviceType.DEVICE_TYPE_CPU,
"nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
"cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
"ascend": DeviceType.DEVICE_TYPE_ASCEND,
"metax": DeviceType.DEVICE_TYPE_METAX,
"moore": DeviceType.DEVICE_TYPE_MOORE,
}
TORCH_DEVICE_TYPE_MAP = {
"cpu": "cpu",
"nvidia": "cuda",
"cambricon": "mlu",
"ascend": "npu",
"metax": "cuda",
"moore": "cuda",
}
def test_torch(input_ids_list, device_):
device = TORCH_DEVICE_TYPE_MAP[device_]
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(
device
)
model.eval()
total_neg_log_likelihood = 0
total_tokens = 0
with torch.no_grad():
for input_ids in input_ids_list:
input_ids = torch.tensor(input_ids, device=device)
# shift inputs and labels
inputs = input_ids[:-1].unsqueeze(0) # [1, seq_len-1]
labels = input_ids[1:].unsqueeze(0) # [1, seq_len-1]
outputs = model(inputs, use_cache=False)
logits = outputs.logits # [1, seq_len-1, vocab_size]
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
# gather log probs of true tokens
true_token_log_probs = log_probs.gather(
dim=-1, index=labels.unsqueeze(-1)
).squeeze(-1)
total_neg_log_likelihood += -true_token_log_probs.sum().item()
total_tokens += labels.numel()
perplexity = torch.exp(torch.tensor(total_neg_log_likelihood / total_tokens))
return perplexity
def test_infinicore(input_ids_list, device_, ndev_):
device = DEVICE_TYPE_MAP[device_]
model = JiugeForCauslLM(
model_path, device, max_tokens=len(input_ids_list[0]), ndev=ndev_
)
perplexity = model.perplexity(input_ids_list)
model.destroy_model_instance()
return perplexity
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument(
"--dev", type=str, default="cpu", choices=DEVICE_TYPE_MAP.keys()
)
parser.add_argument(
"--ndev",
type=int,
default=1,
help="Number of devices to use (default: 1)",
)
args = parser.parse_args()
seq_len = 512
model_path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
texts = dataset["text"]
texts = [t for t in texts if len(t.strip()) > 0]
input_ids_list = []
for text in texts:
ids = tokenizer.encode(text)
        # split long token sequences into non-overlapping chunks of seq_len
        # (trailing tokens that do not fill a whole chunk are dropped)
for i in range(0, len(ids) - seq_len + 1, seq_len):
input_ids_list.append(ids[i : i + seq_len])
perplexity = test_infinicore(input_ids_list, args.dev, args.ndev)
print(f"InfiniCore Perplexity: {perplexity:.2f}")
    if args.ndev == 1:  # TODO: support multi-device testing with torch
perplexity = test_torch(input_ids_list, args.dev)
print(f"Torch Perplexity: {perplexity.item():.2f}")
@@ -112,6 +112,17 @@ def __open_library__():
POINTER(c_float), # float topp
POINTER(c_uint), # unsigned int *output
]
lib.forwardBatch.restype = None
lib.forwardBatch.argtypes = [
POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
POINTER(c_uint), # unsigned int const *req_lens
c_uint, # unsigned int nreq
POINTER(c_uint), # unsigned int const *req_pos
POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches
c_void_p, # void *logits
]
return lib
@@ -123,3 +134,4 @@ destroy_jiuge_model = LIB.destroyJiugeModel
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch = LIB.inferBatch
forward_batch = LIB.forwardBatch
import math
import requests
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--endpoint", type=str, default="/completions")
parser.add_argument("--chunk", type=int, default=512)
args = parser.parse_args()
API_URL = "http://localhost:" + str(args.port) + args.endpoint
CHUNK_SIZE = args.chunk
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# Local tokenizer used for chunking
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
total_neg_log_likelihood = 0.0
total_tokens = 0
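    # Score each chunk by echoing the prompt: with echo=True, max_tokens=0 and
    # logprobs requested, the server returns a log-prob for every prompt token
    # without generating any new ones.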
for example in tqdm(dataset, desc="Evaluating PPL"):
text = example["text"].strip()
if not text:
continue
        # encode, chunk and decode
tokens = tokenizer.encode(text, add_special_tokens=False)
for i in range(0, len(tokens), CHUNK_SIZE):
chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))]
chunk_text = tokenizer.decode(chunk_tokens)
resp = requests.post(
API_URL,
headers={"Content-Type": "application/json"},
json={
"model": "",
"prompt": chunk_text,
"max_tokens": 0,
"temperature": 1.0,
"echo": True,
"logprobs": 0,
},
).json()
logprobs = resp["choices"][0]["logprobs"]["token_logprobs"]
            # the first token has no preceding context, so its logprob is None; skip it
valid_logprobs = [lp for lp in logprobs[1:] if lp is not None]
total_neg_log_likelihood += -sum(valid_logprobs)
total_tokens += len(valid_logprobs)
# ==== Compute final PPL ====
ppl = math.exp(total_neg_log_likelihood / total_tokens)
print(f"Perplexity: {ppl:.4f}")
import requests
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
API_URL = "http://localhost:8000/chat/completions"
MODEL = "FM9G-7B"
PROMPT = ["山东最高的山是?", "给我讲个故事"]
CONCURRENCY = 10  # number of concurrent users
def single_run(user_id):
payload = {
"model": MODEL,
"messages": [{"role": "user", "content": PROMPT[user_id % len(PROMPT)]}],
"max_tokens": 512,
"stream": True
}
headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
print(f"[User {user_id}] Sending request...")
start = time.perf_counter()
resp = requests.post(API_URL, headers=headers, json=payload, stream=True)
resp.raise_for_status()
    ttfb = resp.elapsed.total_seconds()  # time until the HTTP response headers arrived
header_received = time.perf_counter()
if resp.encoding is None:
resp.encoding = 'utf-8'
tokens = 0
chunks = []
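    # Each streamed delta that carries content counts as one "token"; this is an
    # approximation, since a single chunk may contain more than one model token.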
for line in resp.iter_lines(decode_unicode=True):
if not line or line.strip() == "[DONE]":
continue
s = line.strip()
if s.startswith("data:"):
s = s[len("data:"):].strip()
try:
data = json.loads(s)
except json.JSONDecodeError:
continue
text = data.get("choices", [{}])[0].get("delta", {}).get("content")
if text:
chunks.append(text)
tokens += 1
stream_done = time.perf_counter()
    # timing breakdown
stream_time = stream_done - header_received
total_time = stream_done - start
time_per_token_ms = (stream_time / tokens * 1000) if tokens else float('inf')
tps = tokens / stream_time if stream_time > 0 else 0
return {
"user": user_id,
"ttfb": ttfb,
"stream_time": stream_time,
"total_time": total_time,
"tokens": tokens,
"time_per_token_ms": time_per_token_ms,
"tps": tps,
"chunks": chunks
}
def main():
    worst = None
    best = None
    worst_stream = -1.0
    best_stream = float('inf')
results = []
with ThreadPoolExecutor(max_workers=CONCURRENCY) as e:
futures = [e.submit(single_run, uid) for uid in range(CONCURRENCY)]
for future in as_completed(futures):
r = future.result()
results.append(r)
print(
f"User {r['user']} → TTFB = {r['ttfb']:.3f}s, latency = {r['stream_time']:.3f}s, "
f"tokens = {r['tokens']}, time/token = {r['time_per_token_ms']:.2f} ms, "
f"TPS = {r['tps']:.1f} tok/s"
)
if r['stream_time'] > worst_stream:
worst_stream = r['stream_time']
worst = r
if r['stream_time'] < best_stream:
best_stream = r['stream_time']
best = r
# Sort results by user ID
results.sort(key=lambda x: x["user"])
with open("responses.txt", "w", encoding="utf-8") as fw:
for r in results:
fw.write(f"[User {r['user']}]\n")
text = "".join(r["chunks"])
# fixed = text.encode('latin-1').decode('utf-8')
fixed = text
fw.write(fixed)
fw.write("\n\n")
n = CONCURRENCY
avg_ttfb = sum(r['ttfb'] for r in results) / n
avg_token = sum(r['tokens'] for r in results) / n
avg_stream = sum(r['stream_time'] for r in results) / n
avg_tps = sum(r['tps'] for r in results) / n
avg_time_per_token = sum(r['time_per_token_ms'] for r in results) / n
print(f"\n✅ All {n} requests completed.")
print(f"Averages → TTFB = {avg_ttfb:.3f}s, latency = {avg_stream:.3f}s, "
f"tokens = {avg_token:.1f}, TPS = {avg_tps:.1f} tok/s, time/token = {avg_time_per_token:.2f} ms")
if best:
print("\nFastest user:")
print(
f"User {best['user']} → latency = {best['stream_time']:.3f}s, "
f"tokens = {best['tokens']}, TPS = {best['tps']:.1f} tok/s, "
f"time/token = {best['time_per_token_ms']:.2f} ms"
)
if worst:
print("\nSlowest user:")
print(
f"User {worst['user']} → latency = {worst['stream_time']:.3f}s, "
f"tokens = {worst['tokens']}, TPS = {worst['tps']:.1f} tok/s, "
f"time/token = {worst['time_per_token_ms']:.2f} ms"
)
if __name__ == "__main__":
main()
@@ -117,7 +117,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
-                      uint32_t *output) {
+                      uint32_t *output, void *last_logits) {
auto nlayer = meta.nlayer;
auto nkvh = meta.nkvh / ndev;
auto nh = meta.nh / ndev;
@@ -220,12 +220,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(q_rearrange, q);
auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
-            linear(qk_gemm, rearrange_q_buf, k_gemm, 1. / sqrt(dh), 0.0, nullptr, nullptr);
+            linear(qk_gemm, rearrange_q_buf, k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
// softmax
auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
-            linear(attn_val_buf, qk_gemm, v_gemm, 1.0, 0.0, nullptr, nullptr);
+            linear(attn_val_buf, qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
// rearrange attn val
rearrange(o, attn_val_gemm);
@@ -258,32 +258,41 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
// Sample and Output
if (idev == 0) {
-        size_t token_offset = 0;
-        for (uint32_t req = 0; req < nreq; req++) {
-            auto seq_len = req_lens[req];
-            token_offset += seq_len;
-            rmsnorm(logits_out->slice(0, req, 1),
-                    logits_in->slice(0, token_offset - 1, 1),
-                    rsrc.w_out_norm,
-                    meta.epsilon);
-        }
-        linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
-        std::random_device _rd;
-        std::mt19937 gen(_rd());
-        token_offset = 0;
-        for (uint32_t req = 0; req < nreq; req++) {
-            auto seq_len = req_lens[req];
-            float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
-            randomSample(result_buf->memShare({}, result_buf->dtype()),
-                         prob_buf->view_as({dvoc}, {1}),
-                         random_val, topp[req], topk[req], temperature[req]);
-            token_offset += seq_len;
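+        // Forward-only path: apply the final RMSNorm and the output embedding to
+        // every token's hidden state, then copy the full [ntok, dvoc] logits to host.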
+        if (last_logits != nullptr) {
+            rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon);
+            auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
+            linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+            RUN_INFINI(infinirtStreamSynchronize(stream));
+            RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
+        }
-        RUN_INFINI(infinirtStreamSynchronize(stream));
-        RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
-                                  sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
-        for (uint32_t req = 0; req < nreq; req++) {
-            output[req] = result_cpu[req];
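+        // Sampling path: normalize only the last hidden state of each request,
+        // project it to the vocabulary, and sample one output token per request.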
+        if (output != nullptr) {
+            size_t token_offset = 0;
+            for (uint32_t req = 0; req < nreq; req++) {
+                auto seq_len = req_lens[req];
+                token_offset += seq_len;
+                rmsnorm(logits_out->slice(0, req, 1),
+                        logits_in->slice(0, token_offset - 1, 1),
+                        rsrc.w_out_norm,
+                        meta.epsilon);
+            }
+            linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+            std::random_device _rd;
+            std::mt19937 gen(_rd());
+            token_offset = 0;
+            for (uint32_t req = 0; req < nreq; req++) {
+                auto seq_len = req_lens[req];
+                float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
+                randomSample(result_buf->memShare({}, result_buf->dtype()),
+                             prob_buf->view_as({dvoc}, {1}),
+                             random_val, topp[req], topk[req], temperature[req]);
+                token_offset += seq_len;
+            }
+            RUN_INFINI(infinirtStreamSynchronize(stream));
+            RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
+                                      sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
+            for (uint32_t req = 0; req < nreq; req++) {
+                output[req] = uint32_t(result_cpu[req]);
+            }
+        }
}
}
@@ -302,6 +311,7 @@ inferBatch(struct JiugeModel *model,
model->req.req_pos = req_pos;
model->req.kv_caches = kv_caches;
model->req.output = output;
model->req.logits = nullptr;
model->req.temperature = temperature;
model->req.topk = topk;
model->req.topp = topp;
@@ -320,6 +330,38 @@
}
}
__C void
forwardBatch(struct JiugeModel *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits) {
model->req.tokens = tokens;
model->req.ntok = ntok;
model->req.req_lens = req_lens;
model->req.nreq = nreq;
model->req.req_pos = req_pos;
model->req.kv_caches = kv_caches;
model->req.output = nullptr;
model->req.logits = logits;
model->req.temperature = nullptr;
model->req.topk = nullptr;
model->req.topp = nullptr;
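    // Wake every device worker, then wait until each one flips proceed back
    // to false, signalling that this batch has finished on that device.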
for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].proceed = true;
lock.unlock();
model->states[idev].cv_start.notify_one();
}
for (size_t i = model->dev_ids.size(); i > 0; i--) {
auto idev = i - 1;
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
lock.unlock();
}
}
void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceResource *rsrc, InferState &state, InferRequest &req,
infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
CacheManager cache_manager(100);
@@ -348,7 +390,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
req.req_lens, req.nreq, req.req_pos, req.kv_caches,
-                         req.temperature, req.topk, req.topp, req.output);
+                         req.temperature, req.topk, req.topp, req.output, req.logits);
state.proceed = false;
lock.unlock();
......
@@ -49,6 +49,7 @@ struct InferRequest {
const uint32_t *topk;
const float *topp;
uint32_t *output;
void *logits;
};
struct JiugeModel {
......