Commit ac664bce authored by baberabb's avatar baberabb
Browse files

fix loglikelihood_rolling

parent 2362ab41
...@@ -37,7 +37,7 @@ class VLLM(LM): ...@@ -37,7 +37,7 @@ class VLLM(LM):
self.model = LLM( self.model = LLM(
model=pretrained, model=pretrained,
gpu_memory_utilization=0.2, gpu_memory_utilization=0.9,
revision=revision, revision=revision,
dtype=dtype, dtype=dtype,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
...@@ -135,7 +135,7 @@ class VLLM(LM): ...@@ -135,7 +135,7 @@ class VLLM(LM):
utils.get_rolling_token_windows( utils.get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.eot_token_id,
max_seq_len=self.max_length, max_seq_len=self.max_length - 1,
context_len=1, context_len=1,
), ),
) )
...@@ -331,7 +331,9 @@ class VLLM(LM): ...@@ -331,7 +331,9 @@ class VLLM(LM):
# Determine if is_greedy # Determine if is_greedy
is_greedy = True is_greedy = True
for token, logprob_dict in zip(tokens[ctxlen:], continuation_logprobs_dicts): for token, logprob_dict in zip(
tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
):
# Get the token with the maximum log probability from the logprob_dict # Get the token with the maximum log probability from the logprob_dict
if logprob_dict: # Ensure the logprob_dict is not None if logprob_dict: # Ensure the logprob_dict is not None
top_token = max(logprob_dict, key=logprob_dict.get) top_token = max(logprob_dict, key=logprob_dict.get)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment