Commit e63d1396 authored by Leo Gao
Browse files

add warning

parent 95bc8317
...@@ -23,7 +23,7 @@ def get_result(response, ctxlen): ...@@ -23,7 +23,7 @@ def get_result(response, ctxlen):
is_greedy = True is_greedy = True
logprobs = response["logprobs"]["token_logprobs"][:-1] logprobs = response["logprobs"]["token_logprobs"][:-1]
continuation_logprobs = sum(logprobs[ctxlen:]) continuation_logprobs = sum(logprobs[ctxlen:])
print(logprobs[ctxlen:]) # print(logprobs[ctxlen:])
for i in range(ctxlen, len(response["logprobs"]["tokens"][:-1])): for i in range(ctxlen, len(response["logprobs"]["tokens"][:-1])):
token = response["logprobs"]["tokens"][:-1][i] token = response["logprobs"]["tokens"][:-1][i]
...@@ -83,6 +83,8 @@ class GPT3LM(BaseLM): ...@@ -83,6 +83,8 @@ class GPT3LM(BaseLM):
""" """
super().__init__() super().__init__()
assert pass_strings, "so far, this branch only supports GooseAI, and won't work with the regular OpenAI api. this is mostly because there are still some remaining differences between the two apis that make this more complicated than just a drop in replacement. there's no fundamental reason why I couldn't support both on the same branch right now, but it would be a lot of work, and once gooseai finally makes their api conform to the openai api then we won't need this branch anymore and I'll implement something more simple once that does actually happen."
import openai import openai
self.engine = engine self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
...@@ -149,7 +151,7 @@ class GPT3LM(BaseLM): ...@@ -149,7 +151,7 @@ class GPT3LM(BaseLM):
# TODO: the logic is much simpler if we just look at the length of continuation tokens # TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1)) ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))
print(inp) # print(inp)
if self.pass_strings: if self.pass_strings:
inp = self.tok_decode(inp) inp = self.tok_decode(inp)
inps.append(inp) inps.append(inp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment