Commit 6e39bee3 authored by Michael Carilli's avatar Michael Carilli
Browse files

Merge branch 'master' of https://github.com/NVIDIA/apex

parents bfa3e0ee 251cddaf
...@@ -191,11 +191,11 @@ def rnn_cast(backend, fn, verbose=False): ...@@ -191,11 +191,11 @@ def rnn_cast(backend, fn, verbose=False):
# 2) Inputs: either a tuple (for LSTM) or single tensor # 2) Inputs: either a tuple (for LSTM) or single tensor
if isinstance(hiddens, tuple): if isinstance(hiddens, tuple):
new_args.append(tuple(cast_fn(x) for x in hiddens)) new_args.append(tuple(cast_fn(x) for x in hiddens))
elif utils.is_fp_tensor(hidden): elif utils.is_fp_tensor(hiddens):
new_args.append(cast_fn(hidden)) new_args.append(cast_fn(hiddens))
else: else:
# Hidden can, in principle, be `None` -- pass through # Hiddens can, in principle, be `None` -- pass through
new_args.append(hidden) new_args.append(hiddens)
# 3) Batch sizes (0.4 or later only) # 3) Batch sizes (0.4 or later only)
if len(fargs) == 4: if len(fargs) == 4:
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
// #include "ATen/AccumulateType.h" // #include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDATensorMethods.cuh" #include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh" // #include "ATen/cuda/CUDATypeConversion.cuh"
// #include <THC/THCTensorMathReduce.cuh> // #include <THC/THCTensorMathReduce.cuh>
#include <THC/THCGeneral.h> #include <THC/THCGeneral.h>
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#endif #endif
#include "ATen/cuda/CUDATensorMethods.cuh" #include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh" // #include "ATen/cuda/CUDATypeConversion.cuh"
// #include <THC/THCTensorMathReduce.cuh> // #include <THC/THCTensorMathReduce.cuh>
template template
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#endif #endif
#include "ATen/cuda/CUDATensorMethods.cuh" #include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh" // #include "ATen/cuda/CUDATypeConversion.cuh"
// #include <THC/THCTensorMathReduce.cuh> // #include <THC/THCTensorMathReduce.cuh>
template template
......
import torch.cuda import torch.cuda
import ctypes
import os import os
import re import re
import subprocess import subprocess
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment