Unverified commit 31895e5b, authored by Avelina Asada Hadji-Kyriacou and committed by GitHub

Added mixed_precision_dtype arg (#3138)

parent 2ea6114e
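This commit threads a new mixed_precision_dtype keyword through HFLM so that forward passes can run under torch.autocast while the model weights stay in their loaded dtype. A hedged usage sketch, not part of the commit itself; the checkpoint name and dtype values below are illustrative:

    # Illustrative only: "gpt2" and the dtype choices are example values.
    # Weights load in full precision; autocast handles the compute dtype.
    from lm_eval.models.huggingface import HFLM

    lm = HFLM(
        pretrained="gpt2",
        dtype="float32",
        mixed_precision_dtype="bfloat16",
    )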
@@ -76,6 +76,7 @@ class HFLM(TemplateLM):
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
         softmax_dtype: Optional[Union[str, torch.dtype]] = None,
+        mixed_precision_dtype: Optional[Union[str, torch.dtype]] = None,
         batch_size: Optional[Union[int, str]] = 1,
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
@@ -247,6 +248,11 @@ class HFLM(TemplateLM):
         self.softmax_dtype = (
             get_dtype(softmax_dtype) if softmax_dtype is not None else None
         )
+        self.mixed_precision_dtype = (
+            get_dtype(mixed_precision_dtype)
+            if mixed_precision_dtype is not None
+            else None
+        )
         if str(batch_size).startswith("auto"):
             batch_size = batch_size.split(":")
@@ -903,18 +909,23 @@
             logits returned from the model's decoder
         """
         with torch.no_grad():
-            if attn_mask is not None or labels is not None:
-                assert attn_mask is not None and labels is not None
-                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                return self.model(
-                    input_ids=inps, attention_mask=attn_mask, labels=labels
-                ).logits
-            else:
-                assert self.AUTO_MODEL_CLASS in (
-                    transformers.AutoModelForCausalLM,
-                    transformers.AutoModelForVision2Seq,
-                )
-                return self.model(inps).logits
+            with torch.autocast(
+                device_type=self.device.type,
+                dtype=self.mixed_precision_dtype,
+                enabled=self.mixed_precision_dtype is not None,
+            ):
+                if attn_mask is not None or labels is not None:
+                    assert attn_mask is not None and labels is not None
+                    assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
+                    return self.model(
+                        input_ids=inps, attention_mask=attn_mask, labels=labels
+                    ).logits
+                else:
+                    assert self.AUTO_MODEL_CLASS in (
+                        transformers.AutoModelForCausalLM,
+                        transformers.AutoModelForVision2Seq,
+                    )
+                    return self.model(inps).logits
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # temperature = 0.0 if not set
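The autocast wrapper is a no-op whenever mixed_precision_dtype is left unset, because enabled=False turns torch.autocast into a plain pass-through, so existing full-precision behaviour is unchanged. A small standalone illustration (CPU chosen here purely for portability):

    import torch

    x = torch.randn(8, 16)
    w = torch.randn(16, 4)

    # enabled=True: matmuls inside the context run in the requested dtype
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
        print((x @ w).dtype)  # torch.bfloat16

    # enabled=False: the context is inert and float32 compute is preserved
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False):
        print((x @ w).dtype)  # torch.float32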
@@ -934,14 +945,19 @@
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, context.shape[1], context.shape[0]
         )
-        return self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=True,
-            **generation_kwargs,
-        )
+        with torch.autocast(
+            device_type=self.device.type,
+            dtype=self.mixed_precision_dtype,
+            enabled=self.mixed_precision_dtype is not None,
+        ):
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                **generation_kwargs,
+            )
     def _select_cont_toks(
         self, logits: torch.Tensor, contlen: int = None, inplen: int = None
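Because generate is wrapped the same way as _model_call, both loglikelihood scoring and free-form generation pick up the setting. An end-to-end sketch, assuming the lm_eval.simple_evaluate entry point; the task and model names are example values only:

    # Hedged end-to-end sketch; task and model choices are illustrative.
    import lm_eval
    from lm_eval.models.huggingface import HFLM

    lm = HFLM(pretrained="gpt2", mixed_precision_dtype="bfloat16")
    results = lm_eval.simple_evaluate(model=lm, tasks=["lambada_openai"])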