Unverified Commit 31895e5b authored by Avelina Asada Hadji-Kyriacou's avatar Avelina Asada Hadji-Kyriacou Committed by GitHub
Browse files

Added mixed_precision_dtype arg (#3138)

parent 2ea6114e
...@@ -76,6 +76,7 @@ class HFLM(TemplateLM): ...@@ -76,6 +76,7 @@ class HFLM(TemplateLM):
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto", dtype: Optional[Union[str, torch.dtype]] = "auto",
softmax_dtype: Optional[Union[str, torch.dtype]] = None, softmax_dtype: Optional[Union[str, torch.dtype]] = None,
mixed_precision_dtype: Optional[Union[str, torch.dtype]] = None,
batch_size: Optional[Union[int, str]] = 1, batch_size: Optional[Union[int, str]] = 1,
max_batch_size: Optional[int] = 64, max_batch_size: Optional[int] = 64,
trust_remote_code: Optional[bool] = False, trust_remote_code: Optional[bool] = False,
...@@ -247,6 +248,11 @@ class HFLM(TemplateLM): ...@@ -247,6 +248,11 @@ class HFLM(TemplateLM):
self.softmax_dtype = ( self.softmax_dtype = (
get_dtype(softmax_dtype) if softmax_dtype is not None else None get_dtype(softmax_dtype) if softmax_dtype is not None else None
) )
self.mixed_precision_dtype = (
get_dtype(mixed_precision_dtype)
if mixed_precision_dtype is not None
else None
)
if str(batch_size).startswith("auto"): if str(batch_size).startswith("auto"):
batch_size = batch_size.split(":") batch_size = batch_size.split(":")
...@@ -903,6 +909,11 @@ class HFLM(TemplateLM): ...@@ -903,6 +909,11 @@ class HFLM(TemplateLM):
logits returned from the model's decoder logits returned from the model's decoder
""" """
with torch.no_grad(): with torch.no_grad():
with torch.autocast(
device_type=self.device.type,
dtype=self.mixed_precision_dtype,
enabled=self.mixed_precision_dtype is not None,
):
if attn_mask is not None or labels is not None: if attn_mask is not None or labels is not None:
assert attn_mask is not None and labels is not None assert attn_mask is not None and labels is not None
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
...@@ -934,6 +945,11 @@ class HFLM(TemplateLM): ...@@ -934,6 +945,11 @@ class HFLM(TemplateLM):
stopping_criteria = stop_sequences_criteria( stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, context.shape[1], context.shape[0] self.tokenizer, stop, context.shape[1], context.shape[0]
) )
with torch.autocast(
device_type=self.device.type,
dtype=self.mixed_precision_dtype,
enabled=self.mixed_precision_dtype is not None,
):
return self.model.generate( return self.model.generate(
input_ids=context, input_ids=context,
max_length=max_length, max_length=max_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment