"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "aec51e56960545ae3ee192f49524653ad27343c4"
Unverified Commit ce37be9d authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] warn if --fp16 for torch 1.6 (#6977)

parent f72fe1f3
...@@ -3,6 +3,7 @@ import glob ...@@ -3,6 +3,7 @@ import glob
import logging import logging
import os import os
import time import time
import warnings
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
...@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple ...@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from packaging import version
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train from lightning_base import BaseTransformer, add_generic_args, generic_train
...@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule: ...@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args) model: SummarizationModule = SummarizationModule(args)
else: else:
model: SummarizationModule = TranslationModule(args) model: SummarizationModule = TranslationModule(args)
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
warnings.warn("FP16 only seems to work with torch 1.5+apex")
dataset = Path(args.data_dir).name dataset = Path(args.data_dir).name
if ( if (
args.logger_name == "default" args.logger_name == "default"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment