Commit 7bef81f7 authored by Pruthvi Madugundu's avatar Pruthvi Madugundu
Browse files

Updated the handling of CUDAGeneratorImpl.h to new path

parent 980d5f44
......@@ -30,7 +30,7 @@
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#if !defined(NEW_GENERATOR_PATH)
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
......
#include <ATen/ATen.h>
#ifdef OLD_GENERATOR_PATH
#if !defined(NEW_GENERATOR_PATH)
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
......
......@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
#ifdef OLD_GENERATOR_PATH
#if !defined(NEW_GENERATOR_PATH)
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
......
......@@ -5,7 +5,7 @@
#include <torch/extension.h>
#include <ATen/AccumulateType.h>
#ifdef OLD_GENERATOR_PATH
#if !defined(NEW_GENERATOR_PATH)
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
......
......@@ -366,8 +366,8 @@ if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "cuda", "CUDAGeneratorImpl.h")):
generator_flag = ["-DNEW_GENERATOR_PATH"]
if "--fast_layer_norm" in sys.argv:
sys.argv.remove("--fast_layer_norm")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment