Commit a28c0193 authored by svenhendrikx

Allow HFLM model to be initialized with transformers.PreTrainedModel instance

parent b21c8f3d
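
The change: HFLM's constructor now accepts either a model-name string or an already-instantiated model, inheriting the device (and, if none is supplied, the tokenizer) from the instance. A minimal usage sketch of the new path (hypothetical call site; the import assumes this repo's layout, where HFLM lives in lm_eval/models/gpt2.py):

    import transformers
    from lm_eval.models.gpt2 import HFLM

    # Build (or fine-tune, quantize, patch...) a model ahead of time.
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

    # New: pass the live instance instead of a name string. The wrapper
    # takes its device from model.device; if tokenizer is omitted, one is
    # loaded from model.name_or_path.
    lm = HFLM(pretrained=model, tokenizer=tokenizer)

    # The string-based path is unchanged.
    lm2 = HFLM(pretrained="gpt2", device="cpu")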
@@ -31,44 +31,69 @@ class HFLM(BaseLM):
         dtype: Optional[Union[str, torch.dtype]]="auto",
     ):
         super().__init__()
-        assert isinstance(device, str)
-        assert isinstance(pretrained, str)
-        assert isinstance(batch_size, (int, str))
-        device_list = set(
-            ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
-        )
-        if device and device in device_list:
-            self._device = torch.device(device)
-            print(f"Using device '{device}'")
+        # Initialize model
+        if isinstance(pretrained, transformers.PreTrainedModel):
+            self.gpt2 = pretrained
+            self._device = self.gpt2.device
+            if tokenizer:
+                assert isinstance(
+                    tokenizer,
+                    transformers.PreTrainedTokenizer
+                ) or isinstance(
+                    tokenizer,
+                    transformers.PreTrainedTokenizerFast
+                )
+                self.tokenizer = tokenizer
+            else:
+                # Get tokenizer
+                model_name = self.gpt2.name_or_path
+                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                    model_name,
+                    revision=revision,
+                    trust_remote_code=trust_remote_code,
+                )
         else:
-            print("Device not specified")
-            print(f"Cuda Available? {torch.cuda.is_available()}")
-            self._device = (
-                torch.device("cuda")
-                if torch.cuda.is_available()
-                else torch.device("cpu")
-            )
-        # TODO: update this to be less of a hack once subfolder is fixed in HF
-        revision = revision + ("/" + subfolder if subfolder is not None else "")
-        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained,
-            load_in_8bit=load_in_8bit,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            revision=revision,
-            torch_dtype=_get_dtype(dtype),
-            trust_remote_code=trust_remote_code,
-        ).to(self.device)
-        self.gpt2.eval()
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            pretrained if tokenizer is None else tokenizer,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
+            # Initialize device
+            assert isinstance(device, str)
+            device_list = set(
+                ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+            )
+            if device and device in device_list:
+                self._device = torch.device(device)
+                print(f"Using device '{device}'")
+            else:
+                print("Device not specified")
+                print(f"Cuda Available? {torch.cuda.is_available()}")
+                self._device = (
+                    torch.device("cuda")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+            assert isinstance(pretrained, str)
+            revision = revision + ("/" + subfolder if subfolder is not None else "")
+            self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
+                pretrained,
+                load_in_8bit=load_in_8bit,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                revision=revision,
+                torch_dtype=_get_dtype(dtype),
+                trust_remote_code=trust_remote_code,
+            ).to(self.device)
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                tokenizer if tokenizer else pretrained,
+                revision=revision,
+                trust_remote_code=trust_remote_code,
+            )
+        self.gpt2.eval()
         self.vocab_size = self.tokenizer.vocab_size
@@ -82,6 +107,8 @@ class HFLM(BaseLM):
             31373,
         ], self.tokenizer.encode("hello\n\nhello")
+        # Validate batch_size
+        assert isinstance(batch_size, (int, str))
         # setup for automatic batch size detection
         if batch_size == "auto":
             self.batch_size_per_gpu = batch_size
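
The second hunk relocates the batch_size validation so it runs after either construction path; an int or the literal string "auto" is still accepted. Continuing the hypothetical sketch above:

    # "auto" opts in to the harness's automatic batch size detection
    # instead of a fixed per-GPU batch size.
    lm = HFLM(pretrained=model, batch_size="auto")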