Commit f4ad42c1 authored by lcskrishna
Browse files

fix compile args for multi-tensor extension

parent 91003340
......@@ -9,7 +9,6 @@
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
......
......@@ -161,7 +161,8 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
extra_compile_args = nvcc_args_multi_tensor if not is_rocm_pytorch else hipcc_args_multi_tensor))
extra_compile_args = { 'cxx' : ['-O3'] + version_dependent_macros,
'nvcc': nvcc_args_multi_tensor if not is_rocm_pytorch else hipcc_args_multi_tensor}))
print ("INFO: Builidng syncbn extension.")
ext_modules.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment