Commit a28c0193 authored by svenhendrikx's avatar svenhendrikx
Browse files

Allow HFLM model to be initialized with transformers.PreTrainedModel instance

parent b21c8f3d
...@@ -32,10 +32,34 @@ class HFLM(BaseLM): ...@@ -32,10 +32,34 @@ class HFLM(BaseLM):
): ):
super().__init__() super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, (int, str))
# Initialize model
if isinstance(pretrained, transformers.PreTrainedModel):
self.gpt2 = pretrained
self._device = self.gpt2.device
if tokenizer:
assert isinstance(
tokenizer,
transformers.PreTrainedTokenizer
) or isinstance(
tokenizer,
transformers.PreTrainedTokenizerFast
)
self.tokenizer = tokenizer
else:
# Get tokenizer
model_name = self.gpt2.name_or_path
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
)
else:
# Initialize device
assert isinstance(device, str)
device_list = set( device_list = set(
["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())] ["cuda", "cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
) )
...@@ -50,8 +74,8 @@ class HFLM(BaseLM): ...@@ -50,8 +74,8 @@ class HFLM(BaseLM):
if torch.cuda.is_available() if torch.cuda.is_available()
else torch.device("cpu") else torch.device("cpu")
) )
assert isinstance(pretrained, str)
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained( self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
...@@ -62,14 +86,15 @@ class HFLM(BaseLM): ...@@ -62,14 +86,15 @@ class HFLM(BaseLM):
torch_dtype=_get_dtype(dtype), torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
).to(self.device) ).to(self.device)
self.gpt2.eval()
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, tokenizer if tokenizer else pretrained,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
self.gpt2.eval()
self.vocab_size = self.tokenizer.vocab_size self.vocab_size = self.tokenizer.vocab_size
if isinstance( if isinstance(
...@@ -82,6 +107,8 @@ class HFLM(BaseLM): ...@@ -82,6 +107,8 @@ class HFLM(BaseLM):
31373, 31373,
], self.tokenizer.encode("hello\n\nhello") ], self.tokenizer.encode("hello\n\nhello")
# Validate batch_size
assert isinstance(batch_size, (int, str))
# setup for automatic batch size detection # setup for automatic batch size detection
if batch_size == "auto": if batch_size == "auto":
self.batch_size_per_gpu = batch_size self.batch_size_per_gpu = batch_size
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment