Commit df83b67e authored by Michael Carilli

Cleaning up READMEs

parent 6066ddd5
# Introduction

This repository holds NVIDIA-maintained utilities to streamline
...@@ -19,31 +14,20 @@ users as quickly as possible.

### amp: Automatic Mixed Precision

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
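For a concrete picture of those 3 lines, here is a minimal, self-contained sketch (the tiny `Linear` model, the SGD optimizer, and `opt_level="O1"` are placeholder choices; the comprehensive Imagenet example linked below shows the full pattern):
```
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder optimizer

# 1. Let Amp patch the model/optimizer for the chosen precision mode.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

data = torch.randn(4, 10).cuda()
loss = model(data).sum()

# 2-3. Scale the loss (so FP16 gradients don't underflow), then backprop as usual.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```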
[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).

[API Documentation](https://nvidia.github.io/apex/amp.html)

[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

[Moving to the new Amp API] (for users of the deprecated tools formerly called "Amp" and "FP16_Optimizer")
## 2. Distributed Training
...@@ -57,69 +41,60 @@ optimized for NVIDIA's NCCL communication library.

[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)

The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.

### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
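For example, a minimal sketch of dropping `SyncBatchNorm` into a model (this assumes that, since it extends `_BatchNorm`, it is constructed like an ordinary `BatchNorm2d` layer; the layer sizes are arbitrary):
```
import torch
from apex.parallel import SyncBatchNorm

# SyncBatchNorm is used in place of torch.nn.BatchNorm2d; during multiprocess
# (DistributedDataParallel) training, its statistics are allreduced across processes.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    SyncBatchNorm(64),
    torch.nn.ReLU(),
).cuda()
```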
# Requirements

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require Pytorch 1.0 or newer.
We recommend the latest stable release, obtainable from
[https://pytorch.org/](https://pytorch.org/). We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).

If you have any problems building, please file an issue.
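A quick, illustrative way to check your environment against these requirements from Python (standard PyTorch introspection, not an Apex utility):
```
import torch

# Version checks corresponding to the requirements above.
print("PyTorch version:", torch.__version__)       # 0.4+; 1.0+ needed for the C++/CUDA extensions
print("CUDA version PyTorch was built with:", torch.version.cuda)  # CUDA 9 or newer
print("GPU available:", torch.cuda.is_available())
```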
It's often convenient to use Apex in Docker containers. Compatible options include:
* [NVIDIA Pytorch containers from NGC](https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch), which come with Apex preinstalled. To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
* [official Pytorch -devel Dockerfiles](https://hub.docker.com/r/pytorch/pytorch/tags), e.g. `docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7`, in which you can install Apex using the **Quick Start** commands.
# Quick Start

### Linux

For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
```
Apex also supports a Python-only build (required with Pytorch 0.4) via
```
$ pip install -v --no-cache-dir .
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.

`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
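If you are unsure which build you ended up with, one illustrative check is to try importing the CUDA extension module that `setup.py` builds (`amp_C`, as shown in the setup.py diff further down); this is just a diagnostic sketch, not an official Apex API:
```
# Sketch: detect whether Apex's CUDA extensions were built.
# "amp_C" is the CUDAExtension name from Apex's setup.py; a Python-only
# install is expected to fail this import.
try:
    import amp_C  # noqa: F401
    print("Apex CUDA extensions are available.")
except ImportError:
    print("Apex appears to be a Python-only build (no CUDA extensions).")
```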
### Windows support

Windows support is experimental, and Linux is recommended. `python setup.py install --cpp_ext --cuda_ext` may work if you were able to build Pytorch from source
on your system. `python setup.py install` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment,
make sure to install Apex in that same environment.
# amp: Automatic Mixed Precision

## This README documents the deprecated (pre-unified) API.
## Documentation for the current unified API can be found [here](https://nvidia.github.io/apex/)

amp is an experimental tool to enable mixed precision training in
PyTorch with extreme simplicity and overall numerical safety. It
...
...@@ -3,7 +3,15 @@
# But apparently it's ok:
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
class AmpState(object):
    def __init__(self):
        self.hard_override = False


# Attribute stash. Could also just stash things as global module attributes.
_amp_state = AmpState()


def warn_or_err(msg):
    if _amp_state.hard_override:
        print("Warning: " + msg)
    else:
        raise RuntimeError(msg + " If you're sure you know what you're doing, supply " +
                           "hard_override=True to amp.initialize.")
import torch
from torch._six import container_abcs, string_classes
import functools

from ._amp_state import _amp_state, warn_or_err
from .handle import disable_casts
from .scaler import LossScaler
from apex.fp16_utils import convert_network
...@@ -13,9 +13,11 @@ from ..parallel import DistributedDataParallel as apex_DDP


def to_type(dtype, t):
    if not t.is_cuda:
        # This should not be a hard error, since it may be legitimate.
        print("Warning: An input tensor was not cuda.")
    if t.requires_grad:
        # This should be a hard-ish error.
        warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
                    "its gradients will not be properly allreduced by DDP.")
    if t.is_floating_point():
        return t.to(dtype)
...@@ -55,14 +57,14 @@ def check_params_fp32(models):
    for model in models:
        for name, param in model.named_parameters():
            if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
                warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                            "When using amp.initialize, you do not need to call .half() on your model\n"
                            "before passing it, no matter what optimization level you choose.".format(
                            name, param.type()))

        for name, buf in model.named_buffers():
            if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
                warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                            "When using amp.initialize, you do not need to call .half() on your model\n"
                            "before passing it, no matter what optimization level you choose.".format(
                            name, buf.type()))
...@@ -77,7 +79,7 @@ def check_optimizers(optimizers):
            bad_optim_type = "apex.optimizers.FP16_Optimizer"
        if bad_optim_type is not None:
            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
                               "The optimizer(s) passed to amp.initialize() must be bare \n"
                               "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                               "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
                               "soon). You should not manually wrap your optimizer in either \n"
...
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err


class Properties(object):
...@@ -165,21 +165,6 @@ opt_levels = {"O3": O3(),
              "O1": O1(),
              "O0": O0()}


# allow user to directly pass Properties struct as well?
def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
...@@ -193,6 +178,8 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
        loss_scale=None,)
    """
    if not enabled:
        if "hard_override" in kwargs:
            _amp_state.hard_override = kwargs["hard_override"]
        _amp_state.opt_properties = Properties()
        return models, optimizers
...@@ -222,41 +209,43 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
    return _initialize(models, optimizers, _amp_state.opt_properties)
# TODO: is this necessary/useful?
# def check_option_consistency(enabled=True,
#                              opt_level=None,
#                              cast_model_type=None,
#                              patch_torch_functions=None,
#                              keep_batchnorm_fp32=None,
#                              master_weights=None,
#                              loss_scale=None,
#                              enable_ddp_interop=None,
#                              hard_override=False):
#     """
#     Utility function that enables users to quickly check if the option combination they intend
#     to use is permitted. ``check_option_consistency`` does not require models or optimizers
#     to be constructed, and can be called at any point in the script. ``check_option_consistency``
#     is totally self-contained; it does not set any amp global state or affect anything outside
#     of itself.
#     """
#
#     if not enabled:
#         return
#
#     if opt_level not in opt_levels:
#         raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
#     else:
#         opt_properties = opt_levels[opt_level](Properties())
#         print("Selected optimization level {}", opt_levels[opt_level].brief)
#         print("Defaults for this optimization level are:")
#         for k, v in opt_properties.options:
#             print("{:22} : {}".format(k, v))
#
#     print("Processing user overrides (additional kwargs that are not None)...")
#     for k, v in kwargs:
#         if k not in _amp_state.opt_properties.options:
#             raise RuntimeError("Unexpected kwarg {}".format(k))
#         if v is not None:
#             setattr(opt_properties, k, v)
#
#     print("After processing overrides, optimization options are:")
#     for k, v in opt_properties.options:
#         print("{:22} : {}".format(k, v))
Under construction...
# Mixed Precision ImageNet Training in PyTorch

`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/imagenet](https://github.com/pytorch/examples/tree/master/imagenet).
It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset, and illustrates use of the new Amp API along with command-line flags (forwarded to `amp.initialize`) to easily manipulate and switch between various pure and mixed precision training modes.

Three lines enable Amp:
```
# Added after model and optimizer construction
model, optimizer = amp.initialize(model, optimizer, flags...)
...
# loss.backward() changed to:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```
With the new Amp API **you never need to explicitly convert your model, or the input data, to half().**

## Requirements

- Download the ImageNet dataset and move validation images to labeled subfolders
  - The following script may be helpful: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh

## Training
...@@ -30,7 +38,7 @@ CPU data loading bottlenecks.

`O0` and `O3` can be told to use loss scaling via manual overrides, but using loss scaling with `O0`
(pure FP32 training) does not really make sense, and will trigger a warning.
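For reference, a hedged sketch of what such a manual override looks like at the `amp.initialize` level (the toy model/optimizer and the specific values are illustrative, and the exact command-line flag names accepted by `main_amp.py` may differ):
```
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder optimizer

# Start from the O3 defaults, then override individual properties.
# keep_batchnorm_fp32 and loss_scale are amp.initialize keyword arguments;
# the values shown are example choices, not recommendations.
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O3",
                                  keep_batchnorm_fp32=True,  # run batchnorm in FP32
                                  loss_scale=128.0)          # static loss scale override
```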
Softlink training and validation datasets into the current directory:
```
$ ln -sf /data/imagenet/train-jpeg/ train
$ ln -sf /data/imagenet/val-jpeg/ val
...@@ -38,7 +46,7 @@ $ ln -sf /data/imagenet/val-jpeg/ val

### Summary

Amp allows easy experimentation with various pure and mixed precision options.
```
$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./
$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./
...@@ -133,6 +141,16 @@ With Apex DDP, it uses only the current device by default).

The choice of DDP wrapper (Torch or Apex) is orthogonal to the use of Amp and other Apex tools. It is safe to use `apex.amp` with either `torch.nn.parallel.DistributedDataParallel` or `apex.parallel.DistributedDataParallel`. In the future, I may add some features that permit optional tighter integration between `Amp` and `apex.parallel.DistributedDataParallel` for marginal performance benefits, but currently, there's no compelling reason to use Apex DDP versus Torch DDP for most models.
To use DDP with `apex.amp`, the only gotcha is that
```
model, optimizer = amp.initialize(model, optimizer, flags...)
```
must precede
```
model = DDP(model)
```
If DDP wrapping occurs before `amp.initialize`, `amp.initialize` will raise an error.
With both Apex DDP and Torch DDP, you must also call `torch.cuda.set_device(args.local_rank)` within
each process prior to initializing your model or any other tensors.
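Putting these ordering requirements together, here is a minimal sketch of the setup in each process (assumed to be launched with `torch.distributed.launch`; the NCCL backend, the toy `Linear` model, and the SGD optimizer are illustrative assumptions):
```
import argparse
import torch
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # supplied by torch.distributed.launch
args = parser.parse_args()

# 1. Pin this process to its GPU before creating the model or any tensors.
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(10, 10).cuda()                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # placeholder optimizer

# 2. amp.initialize must come before DDP wrapping.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# 3. Wrap with DDP (Apex DDP shown; torch.nn.parallel.DistributedDataParallel also works).
model = DDP(model)
```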
More information can be found in the docs for the
...
...@@ -12,13 +12,16 @@ TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("Apex requires Pytorch 0.4 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

cmdclass = {}
ext_modules = []

if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
    if TORCH_MAJOR == 0:
        raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
                           "found torch.__version__ = {}".format(torch.__version__))

    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension
...@@ -34,7 +37,7 @@ if "--cuda_ext" in sys.argv:
    sys.argv.remove("--cuda_ext")

    if torch.utils.cpp_extension.CUDA_HOME is None:
        raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        ext_modules.append(
            CUDAExtension(name='amp_C',
...