Unverified commit 5cb187f3, authored by Kevin Stephano, committed by GitHub

Update softmax in multihead attention to use the current CUDA stream instead of the default CUDA stream. (#843)
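
For context: in a PyTorch CUDA extension, kernels should be launched on the stream PyTorch currently considers active, not on the legacy default stream; otherwise they can race with, or fail to overlap with, work the caller queued on a side stream. Below is a minimal sketch of the pattern this commit applies; `scale_kernel` and `launch_on_current_stream` are hypothetical stand-ins for the softmax kernels in this PR.

```cpp
#include <ATen/cuda/CUDAContext.h>

// Hypothetical stand-in for the softmax kernels touched by this commit.
__global__ void scale_kernel(float *out, const float *in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 0.5f;
}

void launch_on_current_stream(float *out, const float *in, int n) {
  dim3 threads(256, 1, 1);
  dim3 blocks((n + 255) / 256, 1, 1);
  // The fourth launch parameter selects the stream; omitting it (or passing 0)
  // launches on the default stream, which is exactly the bug fixed here.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<blocks, threads, 0, stream>>>(out, in, n);
}
```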

* Add C++ multihead attention implementation to contrib.

* Add reference test that works for the forward pass.

* Remove cuBLASLt support from multihead attention.

* Add new Python version of self attention.

* Update Python model of MHA with backward pass.

* Fix output linear connection in MHA.

* Clean up compiles and add documentation to PySelfAttention.

* Add encdec Python version of multihead attention. Clean up files.

* Add tests for self and encdec multihead attention.

* Add reference PyTorch implementation of attention with norm and add.

* Add CUTLASS branch definition.

* Add CUTLASS download to the compile step.

* Add norm/add tests.

* Add biases to PyTorch Python versions.

* Add tests and fix issues with Python version of attention masking.

* Create README.md

* Update README.md

* Update README.md

* Update perf test parameters.

* Update README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Fix matmul1 output tensor size. Fix tests that missed the issue.

* Allow for Z dimensions of 64K and greater on batched GEMMs (see the chunking sketch after the diff below).

* Remove redundant imports.

* General cleanup; remove deprecated or unused functions.

* Update multihead attention's softmax to use the current stream instead of the default stream.

* Fix setup.py that was broken in the merge with upstream.

* Update multihead attention strided batched GEMMs to use the current stream instead of the default.
Co-authored-by: pbialecki <pbialecki@nvidia.com>
parent 4a1aa97e
@@ -240,7 +240,7 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, i
     dim3 threads(warp_size, warps_per_block, 1);
     // launch
-    kernel<<<blocks, threads>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
+    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
     return true;
   }
   return false;
@@ -464,7 +464,7 @@ bool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *p
     dim3 threads(warp_size, warps_per_block, 1);
     // launch
-    kernel<<<blocks, threads>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
+    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
     return true;
   }
   return false;
@@ -687,7 +687,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
     dim3 threads(warp_size, warps_per_block, 1);
     // launch
-    kernel<<<blocks, threads>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);
+    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, pad_mask, batch_count, softmax_elements_stride, softmax_elements, mod_seq_len);
     return true;
   }
   return false;
@@ -873,7 +873,7 @@ bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const
     dim3 threads(warp_size, warps_per_block, 1);
     // launch
-    kernel<<<blocks, threads>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
+    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, batch_count, softmax_elements_stride, softmax_elements);
     return true;
   }
   return false;
@@ -1062,7 +1062,7 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
     dim3 threads(warp_size, warps_per_block, 1);
     // launch
-    kernel<<<blocks, threads>>>(grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
+    kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, pad_mask, batch_count, softmax_elements_stride, softmax_elements, pad_batch_stride);
     return true;
   }
   return false;
...
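
Why the softmax change matters in practice: PyTorch lets callers make a side stream current, and any launch hard-coded to the default stream then races with tensors produced on that stream. Below is a hedged sketch of exercising the dispatchers above on a non-default stream, using ATen's stream pool and stream guard; the dispatch call itself is elided.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_side_stream() {
  // Borrow a stream from PyTorch's pool and make it current for this scope.
  at::cuda::CUDAStream side = at::cuda::getStreamFromPool();
  {
    c10::cuda::CUDAStreamGuard guard(side);
    // Here at::cuda::getCurrentCUDAStream() returns `side`, so the patched
    // launches above land on `side` rather than on the default stream.
    // dispatch_softmax<...>(dst, src, ...);  // elided
  }
  side.synchronize();  // wait for side-stream work before reading results
}
```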
@@ -33,8 +33,10 @@ void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m,
                               float beta, half *c, long ldc, long strideC, long batchCount, cublasGemmAlgo_t algo=CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
   cublasOperation_t opa = convertTransToCublasOperation(transa);
   cublasOperation_t opb = convertTransToCublasOperation(transb);
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  cublasSetStream(handle, stream);
   float fAlpha = alpha;
   float fBeta = beta;
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
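
The same rule applies to cuBLAS: a handle keeps whatever stream was last set on it, so it has to be re-bound to the current stream before enqueuing work, which is what the two added lines above do. A minimal sketch of the pattern with a plain FP32 GEMM instead of the half-precision strided-batched call; note that newer PyTorch releases already bind the handle from `at::cuda::getCurrentCUDABlasHandle()` to the current stream, so re-setting it is redundant there but harmless.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

void sgemm_on_current_stream(int m, int n, int k,
                             const float *a, const float *b, float *c) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Bind the shared handle to the current stream; without this, the GEMM
  // is enqueued on whichever stream the handle was last set to.
  cublasSetStream(handle, at::cuda::getCurrentCUDAStream());
  const float alpha = 1.0f, beta = 0.0f;
  // Column-major C = A * B, with A m-by-k, B k-by-n, C m-by-n.
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
              &alpha, a, m, b, k, &beta, c, m);
}
```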
@@ -131,7 +133,7 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
   AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
   // Launch the CUTLASS GEMM kernel.
-  THCudaCheck(Gemm::launch(params));
+  THCudaCheck(Gemm::launch(params, stream));
   // Update batched GEMM params based on completed work
   batchesLeft = batchesLeft - iterBatchCount;
...
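
One more note on the CUTLASS path: the `batchesLeft` / `iterBatchCount` bookkeeping above exists because CUDA caps `gridDim.z` at 65535, which the "Z dimensions of 64K and greater" commit works around by launching batched GEMMs in chunks. Here is a hedged sketch of that chunking idea; `launch_batch` is a hypothetical per-chunk launcher, not the actual CUTLASS entry point.

```cpp
#include <algorithm>
#include <cuda_runtime.h>

// Hypothetical per-chunk launcher standing in for the CUTLASS GEMM launch.
void launch_batch(long firstBatch, long batchCount, cudaStream_t stream);

void launch_all_batches(long batchCount, cudaStream_t stream) {
  const long kMaxGridZ = 65535;  // hardware limit on gridDim.z
  long batchesLeft = batchCount;
  long firstBatch = 0;
  while (batchesLeft > 0) {
    long iterBatchCount = std::min(batchesLeft, kMaxGridZ);
    launch_batch(firstBatch, iterBatchCount, stream);
    // Update batched GEMM params based on completed work.
    batchesLeft -= iterBatchCount;
    firstBatch += iterBatchCount;
  }
}
```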