Commit db7007ae authored by flyingdown

Modify the compute_type of the rocblas_gemm_ex / rocblas_gemm_strided_batched_ex calls to rocblas_datatype_f16_r for fp16, passing alpha/beta/scale as half to match.

parent 32ab028c
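
In short, every fp16 GEMM in these kernels now accumulates in fp16: the compute_type argument of rocblas_gemm_ex / rocblas_gemm_strided_batched_ex changes from rocblas_datatype_f32_r to rocblas_datatype_f16_r, and because rocBLAS expects the host alpha/beta scalars to have the same type as compute_type, the float scalars become half (via __float2half). A minimal sketch of the pattern follows; the helper name, matrix arguments, and leading dimensions are illustrative assumptions, not taken from the patch.

// Sketch of the before/after pattern applied throughout this commit.
// Hypothetical helper; A, B, C, m, n, k are illustrative.
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>   // may be <rocblas.h> on older ROCm installs

rocblas_status gemm_fp16_accum(rocblas_handle handle,
                               const half* A, const half* B, half* C,
                               int m, int n, int k) {
  // was: const float alpha = 1.0f;  const float beta = 0.0f;
  const half alpha = __float2half(1.0f);
  const half beta  = __float2half(0.0f);
  return rocblas_gemm_ex(handle,
                         rocblas_operation_none, rocblas_operation_none,
                         m, n, k,
                         &alpha,
                         A, rocblas_datatype_f16_r, m,   // a_type, lda
                         B, rocblas_datatype_f16_r, k,   // b_type, ldb
                         &beta,
                         C, rocblas_datatype_f16_r, m,   // c_type, ldc
                         C, rocblas_datatype_f16_r, m,   // d_type, ldd
                         rocblas_datatype_f16_r,         // compute_type (was rocblas_datatype_f32_r)
                         rocblas_gemm_algo_standard,
                         0 /*solution_index*/,
                         0 /*flags*/);
}

rocBLAS requires alpha/beta to match the compute type, which is why the scalar declarations and the gemm_bias/mlp_gemm wrappers below switch from float to half alongside the compute_type change; note that fp16 accumulation has less numeric range and precision than the previous fp32 accumulation.
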
......@@ -42,9 +42,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -110,7 +113,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -136,7 +139,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -239,7 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -278,9 +281,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -352,7 +358,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -378,7 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -513,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -539,7 +545,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -565,7 +571,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -591,7 +597,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......
......@@ -51,9 +51,12 @@ std::vector<torch::Tensor> fwd_cuda(
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -137,7 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -163,7 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -266,7 +269,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -330,9 +333,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -416,7 +422,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -442,7 +448,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -578,7 +584,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -604,7 +610,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -630,7 +636,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -656,7 +662,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......
......@@ -37,10 +37,14 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta_zero = 0.0;
// const float beta_one = 1.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta_zero = 0.0;
const half beta_one = 1.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -106,7 +110,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -203,7 +207,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -231,9 +235,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -301,7 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -327,7 +334,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -461,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -487,7 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......
......@@ -36,10 +36,14 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta_zero = 0.0;
// const float beta_one = 1.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta_zero = 0.0;
const half beta_one = 1.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -104,7 +108,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -209,7 +213,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -237,9 +241,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -307,7 +314,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -333,7 +340,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -461,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -487,7 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......
......@@ -36,9 +36,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -102,7 +105,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -205,7 +208,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -233,9 +236,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -303,7 +309,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -329,7 +335,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -464,7 +470,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -490,7 +496,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......
......@@ -40,9 +40,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -124,7 +127,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -228,7 +231,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -280,9 +283,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -361,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -387,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -523,7 +529,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -550,7 +556,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......
......@@ -42,9 +42,48 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
// void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
// float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
// float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
// cublasOperation_t opa = convertTransToCublasOperation(transa);
// cublasOperation_t opb = convertTransToCublasOperation(transb);
// cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
// cublasSetStream(handle, stream);
// float fAlpha = alpha;
// float fBeta = beta;
// //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
// opa, opb, (int)m, (int)n, (int)k,
// (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
// b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
// (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
// d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
// (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
// }
// void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
// float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
// float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
// auto stream = c10::cuda::getCurrentCUDAStream();
// if ( (transa == 't') && (transb == 'n') ) {
// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// } else if ( (transa == 'n') && (transb == 'n') ) {
// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// } else if ( (transa == 'n') && (transb == 't') ) {
// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// } else {
// AT_ASSERTM(false, "TransA and TransB are invalid");
// }
// }
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
......@@ -56,16 +95,16 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
(void*)&alpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
(void*)&beta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
(int)batchCount, rocblas_datatype_f16_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
}
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
......
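
For the strided-batched path, the change is visible at call sites too: gemm_switch_fp32accum (and RocblasStridedBatchedGemm behind it) now takes half alpha/beta, so callers that previously passed float scalars convert them with __float2half. A hypothetical call site under that assumption is sketched below; the function name, pointer arguments, operand order, and transposes are illustrative and not taken from the patch.

// Hypothetical caller of the updated gemm_switch_fp32accum shown above; from the
// caller's point of view only the scalar types change (float -> half).
// Assumes <hip/hip_fp16.h>, <cmath>, and the gemm_switch_fp32accum declaration above.
void scaled_qk_bmm_example(const half* keys, const half* queries, half* scores,
                           int head_dim, int q_seq_len, int k_seq_len,
                           int attn_batches, int lead_dim, int batch_stride,
                           rocblas_int flags) {
  // was: const float scale = 1.0 / sqrt(...);  const float beta = 0.0;
  const half scale = __float2half(1.0f / sqrtf(static_cast<float>(head_dim)));
  const half beta  = __float2half(0.0f);
  gemm_switch_fp32accum('t', 'n',
                        k_seq_len, q_seq_len, head_dim,                  // m, n, k
                        scale,
                        keys, lead_dim, batch_stride,                    // A, lda, strideA
                        queries, lead_dim, batch_stride,                 // B, ldb, strideB
                        beta,
                        scores, k_seq_len, (long)k_seq_len * q_seq_len,  // C, ldc, strideC
                        scores, k_seq_len, (long)k_seq_len * q_seq_len,  // D, ldd, strideD
                        attn_batches, flags);
}
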
......@@ -151,6 +151,7 @@ class EncdecMultiheadAttn(nn.Module):
self.dropout,
)
if is_training:
print('default:', outputs)
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
......
......@@ -164,6 +164,8 @@ cublasStatus_t gemm_bias(
at::Half* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
half hAlpha = __float2half(*alpha);
half hBeta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
......@@ -171,21 +173,21 @@ cublasStatus_t gemm_bias(
m,
n,
k,
alpha,
/* alpha */ &hAlpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
/* beta */ &hBeta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard,
0,
0);
......
......@@ -211,6 +211,8 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
half hAlpha = __float2half(*alpha);
half hBeta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
......@@ -218,21 +220,21 @@ cublasStatus_t mlp_gemm(
m,
n,
k,
alpha,
/* alpha */ &hAlpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
/* beta */ &hBeta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard,
0,
flag);
......