Commit b569f483 authored by Leo Gao

add support for HF revision

parent 3b81a361
@@ -13,7 +13,7 @@ class GPT2LM(LM):
     VOCAB_SIZE = 50257
     EOT_TOKEN_ID = 50256

-    def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
+    def __init__(self, device='cuda', pretrained='gpt2', revision="main", batch_size=1):
         super().__init__()
         assert isinstance(device, str)
@@ -24,7 +24,7 @@ class GPT2LM(LM):
             self.device = torch.device(device)
         else:
             self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
+        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained, revision=revision).to(self.device)
         self.gpt2.eval()

         # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
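For context, a minimal sketch of what the new `revision` keyword does at the transformers level; the model id and revision string below are illustrative assumptions, not part of this commit. In `transformers`, `from_pretrained(..., revision=...)` selects a specific branch, tag, or commit hash of a checkpoint on the Hugging Face Hub, and "main" reproduces the previous default behaviour.

import torch
import transformers

# Pick a device the same way GPT2LM does when none is given.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load a specific revision of a checkpoint from the Hugging Face Hub.
# 'gpt2' and 'main' are placeholders; any Hub model id and any branch,
# tag, or commit hash that exists for that model would work here.
model = transformers.AutoModelForCausalLM.from_pretrained(
    'gpt2',
    revision='main',
).to(device)
model.eval()

With this change, the same argument can be passed straight through the harness class, e.g. GPT2LM(pretrained='gpt2', revision='main').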