Commit 9d3cb277 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.9.0

parent f30c71d6
#!/bin/bash
# https://github.com/pytorch/pytorch/commit/d2e16dd888a9b5fd55bd475d4fcffb70f388d4f0
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${TORCH_VERSION}" = "1.7.0" ]; then
echo "Fix nvcc for PyTorch"
sed -i.bak -e 's/CONSTEXPR_EXCEPT_WIN_CUDA/const/g' /c/tools/miniconda3/envs/test/lib/site-packages/torch/include/torch/csrc/jit/api/module.h
sed -i.bak -e 's/return \*(this->value)/return \*((type\*)this->value)/g' /c/tools/miniconda3/envs/test/lib/site-packages/torch/include/pybind11/cast.h
fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${TORCH_VERSION}" = "1.7.0" ]; then
echo "Fix nvcc for PyTorch 1.7.0"
sed -i.bak '/static constexpr Symbol Kind/d' /c/tools/miniconda3/envs/test/lib/site-packages/torch/include/torch/csrc/jit/ir/ir.h
fi
......@@ -34,7 +34,8 @@ def get_extensions():
extra_link_args = ['-s']
info = parallel_info()
if 'backend: OpenMP' in info and 'OpenMP not found' not in info:
if ('backend: OpenMP' in info and 'OpenMP not found' not in info
and sys.platform != 'darwin'):
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
if sys.platform == 'win32':
extra_compile_args['cxx'] += ['/openmp']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment