Commit 004d6f9b authored by Chenggang Zhao
Browse files

Stricter conditions for aggressive PTX instructions

parent 7de7464e
...@@ -303,6 +303,7 @@ For two-micro-batch overlapping, you can refer to the following figure. With our ...@@ -303,6 +303,7 @@ For two-micro-batch overlapping, you can refer to the following figure. With our
- [ ] Internode kernels - [ ] Internode kernels
- [ ] Low-latency kernels - [ ] Low-latency kernels
- [ ] SM-free kernels and refactors - [ ] SM-free kernels and refactors
- [ ] Fully remove undefined-behavior PTX instructions
## Notices ## Notices
......
...@@ -145,7 +145,7 @@ __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) { ...@@ -145,7 +145,7 @@ __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B" #define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else #else
#define LD_NC_FUNC "ld.volatile.global" #define LD_NC_FUNC "ld.volatile.global.L2::256B"
#endif #endif
// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS // `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS
......
...@@ -42,10 +42,6 @@ if __name__ == '__main__': ...@@ -42,10 +42,6 @@ if __name__ == '__main__':
# Disable internode and low-latency kernels # Disable internode and low-latency kernels
assert disable_nvshmem assert disable_nvshmem
# Disable LD/ST tricks, as some CUDA versions do not support `.L1::no_allocate`
assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1
os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
else: else:
# Prefer H800 series # Prefer H800 series
os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0') os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0')
...@@ -53,6 +49,11 @@ if __name__ == '__main__': ...@@ -53,6 +49,11 @@ if __name__ == '__main__':
# CUDA 12 flags # CUDA 12 flags
nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10']) nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10'])
# Disable LD/ST tricks, as some CUDA versions do not support `.L1::no_allocate`
if os.environ['TORCH_CUDA_ARCH_LIST'].strip() != '9.0':
assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1
os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
# Disable aggressive PTX instructions # Disable aggressive PTX instructions
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')): if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment