Unverified Commit 31895e5b authored by Avelina Asada Hadji-Kyriacou, committed by GitHub

Added mixed_precision_dtype arg (#3138)

parent 2ea6114e
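
This commit threads a new `mixed_precision_dtype` argument through the `HFLM` constructor and wraps the model's forward and generate calls in `torch.autocast`. A minimal usage sketch, assuming the usual `lm_eval.models.huggingface.HFLM` entry point and its `pretrained` keyword (neither is shown in this diff):

```python
# Sketch only: the HFLM import path and the kwargs other than
# mixed_precision_dtype are assumed from the surrounding codebase,
# not taken from this diff.
from lm_eval.models.huggingface import HFLM

# Keep the weights in float32 but let matmuls/attention run under
# bfloat16 autocast during _model_call / _model_generate.
lm = HFLM(
    pretrained="gpt2",
    dtype="float32",
    mixed_precision_dtype="bfloat16",
)

# Leaving mixed_precision_dtype as None (the default) keeps autocast
# disabled, so existing behaviour is unchanged.
lm_plain = HFLM(pretrained="gpt2")
```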
@@ -76,6 +76,7 @@ class HFLM(TemplateLM):
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
         softmax_dtype: Optional[Union[str, torch.dtype]] = None,
+        mixed_precision_dtype: Optional[Union[str, torch.dtype]] = None,
         batch_size: Optional[Union[int, str]] = 1,
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
@@ -247,6 +248,11 @@ class HFLM(TemplateLM):
         self.softmax_dtype = (
             get_dtype(softmax_dtype) if softmax_dtype is not None else None
         )
+        self.mixed_precision_dtype = (
+            get_dtype(mixed_precision_dtype)
+            if mixed_precision_dtype is not None
+            else None
+        )
         if str(batch_size).startswith("auto"):
             batch_size = batch_size.split(":")
@@ -903,18 +909,23 @@ class HFLM(TemplateLM):
             logits returned from the model's decoder
         """
         with torch.no_grad():
-            if attn_mask is not None or labels is not None:
-                assert attn_mask is not None and labels is not None
-                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                return self.model(
-                    input_ids=inps, attention_mask=attn_mask, labels=labels
-                ).logits
-            else:
-                assert self.AUTO_MODEL_CLASS in (
-                    transformers.AutoModelForCausalLM,
-                    transformers.AutoModelForVision2Seq,
-                )
-                return self.model(inps).logits
+            with torch.autocast(
+                device_type=self.device.type,
+                dtype=self.mixed_precision_dtype,
+                enabled=self.mixed_precision_dtype is not None,
+            ):
+                if attn_mask is not None or labels is not None:
+                    assert attn_mask is not None and labels is not None
+                    assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
+                    return self.model(
+                        input_ids=inps, attention_mask=attn_mask, labels=labels
+                    ).logits
+                else:
+                    assert self.AUTO_MODEL_CLASS in (
+                        transformers.AutoModelForCausalLM,
+                        transformers.AutoModelForVision2Seq,
+                    )
+                    return self.model(inps).logits
 
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # temperature = 0.0 if not set
@@ -934,14 +945,19 @@ class HFLM(TemplateLM):
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, context.shape[1], context.shape[0]
         )
-        return self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=True,
-            **generation_kwargs,
-        )
+        with torch.autocast(
+            device_type=self.device.type,
+            dtype=self.mixed_precision_dtype,
+            enabled=self.mixed_precision_dtype is not None,
+        ):
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                **generation_kwargs,
+            )
 
     def _select_cont_toks(
         self, logits: torch.Tensor, contlen: int = None, inplen: int = None
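
For reference, a small self-contained sketch of the `torch.autocast` pattern the new code relies on: with `enabled=False` the context manager is a no-op and ops keep their original dtype, while with `enabled=True` eligible ops run in the requested lower-precision dtype (CPU and bfloat16 are used here purely for illustration):

```python
import torch

x = torch.randn(8, 8)
w = torch.randn(8, 8)

for mp_dtype in (None, torch.bfloat16):
    # Mirrors the diff: autocast is only enabled when a dtype was provided.
    with torch.autocast(
        device_type="cpu",
        dtype=mp_dtype,
        enabled=mp_dtype is not None,
    ):
        y = x @ w
    # Prints torch.float32 for the disabled case, torch.bfloat16 for the enabled one.
    print(mp_dtype, y.dtype)
```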