Commit 341663a9 authored by researcher2's avatar researcher2
Browse files

Small fixes

parent 55e62507
......@@ -15,12 +15,12 @@ class HFLM(BaseLM):
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self.device = torch.device(device)
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specificed")
print(f"Cuda Available? {torch.cuda.is_available()}")
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# TODO: update this to be less of a hack once subfolder is fixed in HF
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
......
......@@ -43,6 +43,7 @@ def parse_args():
parser.add_argument('--decontaminate', action="store_true")
parser.add_argument('--ngrams_path', default=None)
parser.add_argument('--ngrams_n_size', type=int, default=None)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
......@@ -67,11 +68,6 @@ def pattern_match(patterns, source_list):
task_names.add(matching)
return list(task_names)
def main():
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
def main():
args = parse_args()
if not ensure_correct_decontamination_params(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment