Unverified Commit 80b90b9d authored by ptrblck, committed by GitHub

Fix deprecated calls in multihead_attn and ninja build failure (#746)



* disable ninja for multihead_attn

* fix getCurrentStream in multihead_attn
Co-authored-by: pbialecki <pbialecki@nvidia.com>
parent 20d00ab1
@@ -147,7 +147,7 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
 void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
                            float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                            float beta, half *c, long ldc, long strideC, long batchCount) {
-  cudaStream_t stream = THCState_getCurrentStream(state);
+  auto stream = c10::cuda::getCurrentCUDAStream();
   //printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
   if ( (transa == 't') && (transb == 'n') ) {
     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
...
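The hunk above swaps the deprecated THC call for the c10 stream API. A minimal sketch of the same pattern follows, assuming a PyTorch C++/CUDA extension where <c10/cuda/CUDAStream.h> is available; the helper function name is illustrative and not part of the patch.

```cpp
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Illustrative helper (not in the patch): obtain the current stream the way
// the patched code now does and use it for kernel launches or cuBLAS calls.
void run_on_current_stream() {
  // Deprecated: cudaStream_t stream = THCState_getCurrentStream(state);
  // Current: c10::cuda::getCurrentCUDAStream() returns a c10::cuda::CUDAStream,
  // which converts implicitly to cudaStream_t.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  // ... launch kernels on `stream`, or pass it to cublasSetStream(...) ...
  (void)stream;
}
```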
@@ -203,7 +203,7 @@ if "--fast_multihead_attn" in sys.argv:
     sys.argv.remove("--fast_multihead_attn")
     from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
+    cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
     if torch.utils.cpp_extension.CUDA_HOME is None:
         raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
...
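The setup.py hunk disables ninja for the fast_multihead_attn build, which is the fix for the ninja build failure named in the commit title. Below is a minimal sketch of the same setup.py pattern; the extension name and source path are placeholders, not the ones used in the repository.

```python
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Route build_ext through BuildExtension with ninja disabled, as the patch does.
cmdclass = {"build_ext": BuildExtension.with_options(use_ninja=False)}

ext_modules = [
    CUDAExtension(
        name="fast_multihead_attn_example",      # placeholder name for illustration
        sources=["csrc/example_attn_cuda.cu"],   # hypothetical source file
    ),
]
```

With use_ninja=False, torch.utils.cpp_extension falls back to the plain setuptools compiler driver instead of generating a ninja build, sidestepping the ninja failure for this extension.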