Unverified Commit ce37be9d authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] warn if --fp16 for torch 1.6 (#6977)

parent f72fe1f3
......@@ -3,6 +3,7 @@ import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
......@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from packaging import version
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
......@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
warnings.warn("FP16 only seems to work with torch 1.5+apex")
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment