Commit 1fd257e2 authored by Abhishree

Enable the following modules in apex/contrib:

1) multihead_attn
2) xentropy
3) fused_adam and distributed_fused_adam
parent 297ab210
...@@ -4,3 +4,6 @@ build ...@@ -4,3 +4,6 @@ build
docs/build docs/build
*~ *~
__pycache__ __pycache__
*.hip
*_hip.*
*hip*
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd( ...@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."); m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
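The two #undef lines above re-enable the __half arithmetic operators and float-to-half conversions that hip_fp16.h hides when the __HIP_NO_HALF_OPERATORS__ / __HIP_NO_HALF_CONVERSIONS__ guards are defined (PyTorch's build defines them so they do not collide with its own at::Half overloads). As a rough illustration of what this buys in device code, here is a minimal sketch assuming a half-capable GPU; the kernel name and shapes are made up for the example and are not part of the commit:

// Illustration only: with the half-operator/conversion guards undefined,
// device code can convert float to __half and multiply two __half values directly.
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
#include <cuda_fp16.h>   // the HIP build maps this header to hip/hip_fp16.h

__global__ void scale_half_sketch(__half* data, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        __half s = scale;       // float -> __half conversion (needs conversions enabled)
        data[i] = data[i] * s;  // __half * __half operator (needs operators enabled)
    }
}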
...@@ -21,7 +25,7 @@ extern THCState *state; ...@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace encdec_norm_add { namespace encdec_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()), //static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, c_type,
output_lin_q_dim,
q_lin_results_ptr,
d_type,
output_lin_q_dim, output_lin_q_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
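Throughout these files each cublasGemmEx call is rewritten as rocblas_gemm_ex. The rocBLAS entry point differs from the cuBLAS one in two ways that explain the extra arguments: the output is split into a C (accumulate-in) and a D (result) matrix, here both pointed at the same buffer, and kernel selection is explicit via algo / solution_index / flags instead of a single CUBLAS_GEMM_* enum. The cublasSetMathMode(CUBLAS_TENSOR_OP_MATH) calls that bracketed these GEMMs are dropped in the same pass. The a_type / b_type / c_type / d_type / compute_type / algo / solution_index / flags names are declared earlier in the function, outside this hunk; one plausible setup for half-precision operands with fp32 accumulation, offered only as a sketch of what those declarations likely look like, is:

#include <rocblas.h>   // <rocblas/rocblas.h> on newer ROCm releases

// Assumed argument setup for the rocblas_gemm_ex calls in this file; the real
// declarations live elsewhere in fwd_cuda/bwd_cuda and are not part of this hunk.
rocblas_datatype a_type       = rocblas_datatype_f16_r;     // A operand: fp16, real
rocblas_datatype b_type       = rocblas_datatype_f16_r;     // B operand: fp16, real
rocblas_datatype c_type       = rocblas_datatype_f16_r;     // C (accumulate-in) matrix: fp16
rocblas_datatype d_type       = rocblas_datatype_f16_r;     // D (output) matrix: fp16
rocblas_datatype compute_type = rocblas_datatype_f32_r;     // accumulate in fp32
rocblas_gemm_algo algo        = rocblas_gemm_algo_standard; // let rocBLAS pick the kernel
int32_t  solution_index       = 0;                          // only used by the solution-index algo
uint32_t flags                = 0;                          // no special GEMM flags

The strided-batched helper gemm_switch_fp32accum gains a matching output pointer / leading dimension / stride triplet in the hunks below, presumably because rocblas_gemm_strided_batched_ex separates C and D in the same way.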
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
k_lin_results_ptr, k_lin_results_ptr,
CUDA_R_16F, c_type,
output_lin_kv_dim, output_lin_kv_dim,
CUDA_R_32F, k_lin_results_ptr,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
output_lin_kv_dim,
compute_type,
algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
a_layout_t, a_layout_t,
...@@ -169,6 +181,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -169,6 +181,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Padded Softmax // Padded Softmax
...@@ -231,10 +246,13 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -231,10 +246,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -242,18 +260,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -242,18 +260,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_results.data_ptr()),
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
...@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q); total_tokens_q);
} }
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_mean,
...@@ -367,8 +387,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -367,8 +387,6 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const*>(output_grads.data_ptr()),
...@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim_q, lead_dim_q,
batch_stride_q, batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
...@@ -516,10 +553,13 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -516,10 +553,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim_kv, lead_dim_kv,
batch_stride_kv, batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()), //static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()), static_cast<void*>(input_lin_q_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_lin_q_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()), static_cast<void*>(input_weight_q_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); algo,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); solution_index,
flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv, batches_kv,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()), static_cast<void*>(input_weight_kv_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half,float>(
...@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
template<typename U> __device__ template<typename U> __device__
void cuWelfordOnlineSum( void cuWelfordOnlineSum(
const U curr, const U curr,
...@@ -84,9 +85,9 @@ void cuWelfordMuSigma2( ...@@ -84,9 +85,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB); U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB); U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB); U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
} }
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
...@@ -122,8 +123,8 @@ void cuWelfordMuSigma2( ...@@ -122,8 +123,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/U(n2); sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0); sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
} }
} }
} }
...@@ -180,9 +181,9 @@ void cuWelfordMuSigma2( ...@@ -180,9 +181,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB); float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB); float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB); float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
} }
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
...@@ -218,8 +219,8 @@ void cuWelfordMuSigma2( ...@@ -218,8 +219,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/float(n2); sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0); sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
} }
} }
} }
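The WARP_SHFL / WARP_SHFL_XOR calls in this file now pass an explicit width of 32. The surrounding reductions assume 32-lane groups (lane indices are masked with &31 and the loop runs five shuffle steps), so on ROCm, where the default wavefront is 64 lanes wide, the width argument keeps each shuffle confined to a 32-lane subgroup instead of spanning the whole wavefront. A sketch of the kind of width-aware wrapper these calls assume follows; the real WARP_SHFL definitions live in apex's shared compatibility header and may differ in detail.

// Hedged sketch of width-aware warp-shuffle wrappers (not the actual apex macros).
#if defined(__HIP_PLATFORM_HCC__)
// HIP shuffles take no sync mask; width limits the shuffle to a 32-lane subgroup
// of the 64-lane wavefront.
#define WARP_SHFL(var, src_lane, width)      __shfl((var), (src_lane), (width))
#define WARP_SHFL_XOR(var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#else
// CUDA 9+ synchronized shuffles over a full-warp mask.
#define WARP_SHFL(var, src_lane, width)      __shfl_sync(0xffffffffu, (var), (src_lane), (width))
#define WARP_SHFL_XOR(var, lane_mask, width) __shfl_xor_sync(0xffffffffu, (var), (lane_mask), (width))
#endif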
...@@ -227,9 +228,19 @@ void cuWelfordMuSigma2( ...@@ -227,9 +228,19 @@ void cuWelfordMuSigma2(
template<typename U> U rsqrt(U v) { template<typename U> U rsqrt(U v) {
return U(1) / sqrt(v); return U(1) / sqrt(v);
} }
//template<> float rsqrt(float v) {
// return rsqrtf(v);
//}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
return rsqrtf(v);
}
#else
template<> float rsqrt(float v) { template<> float rsqrt(float v) {
return rsqrtf(v); return rsqrtf(v);
} }
#endif
template<> double rsqrt(double v) { template<> double rsqrt(double v) {
return rsqrt(v); return rsqrt(v);
} }
...@@ -290,7 +301,7 @@ void cuApplyLayerNorm( ...@@ -290,7 +301,7 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
// 2) Tensors are contiguous // 2) Tensors are contiguous
// //
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared; SharedMemory<U> shared;
U* buf = shared.getPointer(); U* buf = shared.getPointer();
U mu,sigma2; U mu,sigma2;
...@@ -529,7 +540,7 @@ void cuComputeGradInput( ...@@ -529,7 +540,7 @@ void cuComputeGradInput(
const T* gamma, const T* gamma,
T* grad_input) T* grad_input)
{ {
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0); U sum_loss1 = U(0);
U sum_loss2 = U(0); U sum_loss2 = U(0);
const U c_mean = mean[i1]; const U c_mean = mean[i1];
...@@ -574,8 +585,8 @@ void cuComputeGradInput( ...@@ -574,8 +585,8 @@ void cuComputeGradInput(
} }
// intra-warp reductions // intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) { for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
} }
// inter-warp reductions // inter-warp reductions
if (blockDim.y > 1) { if (blockDim.y > 1) {
......
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd( ...@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn } // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward."); m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward."); m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
} }
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -21,7 +25,7 @@ extern THCState *state; ...@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
namespace cublas_gemmex { namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda( std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask, bool use_time_mask,
...@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
CUDA_R_16F, c_type,
output_lin_dim,
q_lin_results_ptr,
d_type,
output_lin_dim, output_lin_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr), static_cast<half*>(softmax_results_ptr),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Padded Softmax // Padded Softmax
...@@ -203,10 +215,13 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -203,10 +215,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()), static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -214,17 +229,23 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -214,17 +229,23 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(output_lin_results.data_ptr()),
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); d_type,
embed_dim,
compute_type,
algo,
solution_index,
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
...@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens); total_tokens);
} }
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_mean,
...@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const*>(output_grads.data_ptr()),
...@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -367,17 +389,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -367,17 +389,22 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
CUDA_R_16F, b_type,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( state, gemm_switch_fp32accum( state,
...@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches); attn_batches);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
...@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr, v_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
...@@ -456,6 +489,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -456,6 +489,9 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr, q_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
...@@ -476,10 +512,13 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -476,10 +512,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr, k_lin_grads_ptr,
lead_dim, lead_dim,
batch_stride, batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()), //static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()), static_cast<void*>(input_lin_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim, embed_dim,
CUDA_R_32F, static_cast<void*>(input_lin_grads.data_ptr()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP)); d_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); embed_dim,
compute_type,
algo,
solution_index,
flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, THCublasCheck(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
CUDA_R_16F, a_type,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
CUDA_R_16F, b_type,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
CUDA_R_16F, c_type,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
d_type,
embed_dim, embed_dim,
CUDA_R_32F, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); algo,
solution_index,
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half,float>(
...@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
...@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda(
}; };
} }
} // end namespace cublas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_norm_add } // end namespace self_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
...@@ -234,12 +234,12 @@ void fused_adam_cuda( ...@@ -234,12 +234,12 @@ void fused_adam_cuda(
} }
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) { if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients //all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type //dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(), p.DATA_PTR<accscalar_t>(),
...@@ -308,12 +308,12 @@ void fused_adam_cuda_mt( ...@@ -308,12 +308,12 @@ void fused_adam_cuda_mt(
size_t tl_sz = tensor_lists.size(); size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) { if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half || tensor_lists[3][0].scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients //all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type //dispatch is done on the gradient type
if (tl_sz == 5) { if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>( multi_tensor_apply<5>(
BLOCK_SIZE, BLOCK_SIZE,
...@@ -330,7 +330,7 @@ void fused_adam_cuda_mt( ...@@ -330,7 +330,7 @@ void fused_adam_cuda_mt(
decay); decay);
); );
} else { } else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>( multi_tensor_apply<4>(
BLOCK_SIZE, BLOCK_SIZE,
...@@ -846,13 +846,13 @@ void fused_reversible_adam_cuda( ...@@ -846,13 +846,13 @@ void fused_reversible_adam_cuda(
} }
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) { if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients //all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type //dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors using namespace at; // prevents "toString is undefined" errors
if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) { if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(), p.DATA_PTR<accscalar_t>(),
...@@ -871,7 +871,7 @@ void fused_reversible_adam_cuda( ...@@ -871,7 +871,7 @@ void fused_reversible_adam_cuda(
); );
} else { } else {
AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type"); AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>( reversible_adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.DATA_PTR<accscalar_t>(), p.DATA_PTR<accscalar_t>(),
...@@ -991,12 +991,12 @@ void fused_maybe_adam_undo_cuda( ...@@ -991,12 +991,12 @@ void fused_maybe_adam_undo_cuda(
} }
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.scalar_type() == at::ScalarType::Half) { if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) {
//all other values should be fp32 for half gradients //all other values should be fp32 for half gradients
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type //dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL, overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
......
...@@ -187,7 +187,7 @@ void multi_tensor_fused_adam_cuda( ...@@ -187,7 +187,7 @@ void multi_tensor_fused_adam_cuda(
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tl_sz == 5) { if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>( multi_tensor_apply<5>(
BLOCK_SIZE, BLOCK_SIZE,
...@@ -206,7 +206,7 @@ void multi_tensor_fused_adam_cuda( ...@@ -206,7 +206,7 @@ void multi_tensor_fused_adam_cuda(
(adamMode_t) mode); (adamMode_t) mode);
); );
} else { } else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>( multi_tensor_apply<4>(
BLOCK_SIZE, BLOCK_SIZE,
......
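Both the fused Adam kernels and the distributed multi-tensor variant move from DISPATCH_FLOAT_AND_HALF to DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16, and the scalar-type guards now accept at::ScalarType::BFloat16 alongside Half, so bf16 gradients take the same mixed-precision path as fp16 ones (parameters are still required to be fp32). The extended dispatch macro itself is defined in the extension's type-shim header rather than in this commit; a minimal sketch of the shape such a macro usually has in apex-style code is given below, with the exact wording an assumption.

// Sketch of an apex-style dispatch macro with a BFloat16 branch added.
// The real DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16 lives in the type-shim header
// shipped with these extensions and may differ in detail.
#include <ATen/ATen.h>

#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...)          \
    switch (TYPE) {                                                           \
        case at::ScalarType::Float: {                                         \
            using scalar_t_##LEVEL = float;                                   \
            __VA_ARGS__;                                                      \
            break;                                                            \
        }                                                                     \
        case at::ScalarType::Half: {                                          \
            using scalar_t_##LEVEL = at::Half;                                \
            __VA_ARGS__;                                                      \
            break;                                                            \
        }                                                                     \
        case at::ScalarType::BFloat16: {                                      \
            using scalar_t_##LEVEL = at::BFloat16;                            \
            __VA_ARGS__;                                                      \
            break;                                                            \
        }                                                                     \
        default:                                                              \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
    }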
...@@ -586,7 +586,7 @@ std::vector<Tensor> host_softmax_xentropy( ...@@ -586,7 +586,7 @@ std::vector<Tensor> host_softmax_xentropy(
const Tensor & labels_, const Tensor & labels_,
const float smoothing, const float smoothing,
const bool half_to_float){ const bool half_to_float){
if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half,"conversion is supported for Half type only"); if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half || input_.type().scalarType() == ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only");
AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long"); AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long");
auto input = input_.contiguous(); auto input = input_.contiguous();
...@@ -617,7 +617,7 @@ std::vector<Tensor> host_softmax_xentropy( ...@@ -617,7 +617,7 @@ std::vector<Tensor> host_softmax_xentropy(
dim3 grid(outer_size); dim3 grid(outer_size);
using namespace at; using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input.scalar_type(), 0, "host_softmax_xentropy",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0); const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size); dim3 block = SoftMax_getBlockSize(ILP, dim_size);
...@@ -685,7 +685,7 @@ Tensor host_softmax_xentropy_backward( ...@@ -685,7 +685,7 @@ Tensor host_softmax_xentropy_backward(
dim3 grid(outer_size); dim3 grid(outer_size);
DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward", DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
using accscalar_t = acc_type<scalar_t_0, true>; using accscalar_t = acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0); const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size); dim3 block = SoftMax_getBlockSize(ILP, dim_size);
...@@ -724,7 +724,7 @@ at::Tensor softmax_xentropy_backward_cuda( ...@@ -724,7 +724,7 @@ at::Tensor softmax_xentropy_backward_cuda(
const float smoothing) { const float smoothing) {
bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType(); bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType();
if (half_to_float) { if (half_to_float) {
AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && logits.type().scalarType() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && (logits.type().scalarType() == ScalarType::Half || logits.type().scalarType() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half or at::Bfloat16 and grad to be at::Float");
} }
return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float); return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
} }
...@@ -263,6 +263,6 @@ class EncdecAttnFunc(torch.autograd.Function): ...@@ -263,6 +263,6 @@ class EncdecAttnFunc(torch.autograd.Function):
input_q_grads, input_kv_grads, \ input_q_grads, input_kv_grads, \
input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \ input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \ input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \
None, None None, None, None
encdec_attn_func = EncdecAttnFunc.apply encdec_attn_func = EncdecAttnFunc.apply
...@@ -9,7 +9,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function): ...@@ -9,7 +9,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function):
dropout_prob_t = torch.tensor([dropout_prob]) dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([]) null_tensor = torch.tensor([])
use_mask = (pad_mask is not None) use_mask = (pad_mask is not None)
print("---use_mask-----",use_mask)
lyr_nrm_results, \ lyr_nrm_results, \
lyr_nrm_mean, \ lyr_nrm_mean, \
lyr_nrm_invvar, \ lyr_nrm_invvar, \
......
...@@ -230,6 +230,6 @@ class SelfAttnFunc(torch.autograd.Function): ...@@ -230,6 +230,6 @@ class SelfAttnFunc(torch.autograd.Function):
input_grads, \ input_grads, \
input_weight_grads, output_weight_grads, \ input_weight_grads, output_weight_grads, \
input_bias_grads, output_bias_grads, \ input_bias_grads, output_bias_grads, \
None, None None, None, None
self_attn_func = SelfAttnFunc.apply self_attn_func = SelfAttnFunc.apply
...@@ -34,6 +34,12 @@ def check_if_rocm_pytorch(): ...@@ -34,6 +34,12 @@ def check_if_rocm_pytorch():
IS_ROCM_PYTORCH = check_if_rocm_pytorch() IS_ROCM_PYTORCH = check_if_rocm_pytorch()
if IS_ROCM_PYTORCH:
rocm_include_dirs = ["/opt/rocm/include/hiprand", "/opt/rocm/include/rocrand"]
else:
rocm_include_dirs = []
include_dirs=[os.path.join(this_dir, 'csrc')] + rocm_include_dirs
if not torch.cuda.is_available() and not IS_ROCM_PYTORCH: if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
# https://github.com/NVIDIA/apex/issues/486 # https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
...@@ -144,17 +150,18 @@ if "--distributed_adam" in sys.argv: ...@@ -144,17 +150,18 @@ if "--distributed_adam" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None: if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
nvcc_args_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_adam = ['-O3'] + version_dependent_macros
ext_modules.append( ext_modules.append(
CUDAExtension(name='distributed_adam_cuda', CUDAExtension(name='distributed_adam_cuda',
sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp', sources=['./apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'], './apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/optimizers/'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
'--use_fast_math'] + version_dependent_macros}))
if "--distributed_lamb" in sys.argv: if "--distributed_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -273,9 +280,9 @@ if "--xentropy" in sys.argv: ...@@ -273,9 +280,9 @@ if "--xentropy" in sys.argv:
print ("INFO: Building the xentropy extension.") print ("INFO: Building the xentropy extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='xentropy_cuda', CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp', sources=['./apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'], './apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/xentropy/'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
...@@ -295,9 +302,9 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -295,9 +302,9 @@ if "--deprecated_fused_adam" in sys.argv:
hipcc_args_fused_adam = ['-O3'] + version_dependent_macros hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
ext_modules.append( ext_modules.append(
CUDAExtension(name='fused_adam_cuda', CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp', sources=['./apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'], './apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/optimizers/'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam})) 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
...@@ -368,17 +375,21 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -368,17 +375,21 @@ if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False) cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None: if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
if not IS_ROCM_PYTORCH:
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('arch=compute_80,code=sm_80')
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
nvcc_args_mha = ['-O3', '-gencode', 'arch=compute_70,code=sm_70', '-I./apex/contrib/csrc/multihead_attn/cutlass/', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', '--expt-extended-lambda', '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
hipcc_args_mha = ['-O3', '-I./apex/contrib/csrc/multihead_attn/cutlass/', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout', CUDAExtension(name='fast_additive_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp', sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
...@@ -446,17 +457,11 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -446,17 +457,11 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add', CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp', sources=['./apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'], './apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/multihead_attn/'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn', CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
...@@ -472,17 +477,11 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -472,17 +477,11 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add', CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp', sources=['./apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'], './apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/multihead_attn/'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
setup( setup(
name='apex', name='apex',
......