"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "59602eea096767bece389937c9e2031abbe21275"
Commit 13842e41 authored by Joe Davison, committed by GitHub

PPL guide minor code snippet fix (#7938)

parent 0e24e4c1
@@ -125,18 +125,19 @@ are 512 preceding tokens available to condition on).
 lls = []
 for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
     begin_loc = max(i + stride - max_length, 0)
-    end_loc = i + stride
+    end_loc = min(i + stride, encodings.input_ids.size(1))
+    trg_len = end_loc - i    # may be different from stride on last loop
     input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
     target_ids = input_ids.clone()
-    target_ids[:,:-stride] = -100
+    target_ids[:,:-trg_len] = -100
     with torch.no_grad():
         outputs = model(input_ids, labels=target_ids)
-        log_likelihood = outputs[0] * stride
+        log_likelihood = outputs[0] * trg_len
     lls.append(log_likelihood)
-ppl = torch.exp(torch.stack(lls).sum() / i)
+ppl = torch.exp(torch.stack(lls).sum() / end_loc)
 Running this with the stride length equal to the max input length is
 equivalent to the suboptimal, non-sliding-window strategy we discussed above.
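To make the index arithmetic in this change easier to check, here is a minimal sketch that reproduces only the window bookkeeping from the patched loop, with no model or tokenizer involved. The helper name `window_spans` and the small example sizes are assumptions chosen for illustration; they are not part of the commit.

```python
def window_spans(seq_len, max_length, stride):
    """Return (begin_loc, end_loc, trg_len) for each step of the patched loop.

    Hypothetical helper: it mirrors the index arithmetic in the diff,
    but does not run a model or compute any loss.
    """
    spans = []
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)  # clamp to the sequence end
        trg_len = end_loc - i               # shorter than stride on the last step
        spans.append((begin_loc, end_loc, trg_len))
    return spans


# stride == max_length: windows do not overlap (the non-sliding case).
non_sliding = window_spans(seq_len=10, max_length=4, stride=4)

# stride < max_length: each window keeps earlier tokens as context.
sliding = window_spans(seq_len=10, max_length=4, stride=2)

# Every token is scored exactly once in both cases.
scored_non_sliding = sum(trg_len for _, _, trg_len in non_sliding)
scored_sliding = sum(trg_len for _, _, trg_len in sliding)
```

With these example sizes the last non-sliding window scores only 2 tokens. The pre-patch code would have weighted it by `stride` (4) and divided by `i` (8) rather than by the true token count (10), which is the mismatch this commit addresses.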