Commit db7007ae authored by flyingdown

modify rocblas_gemm_ex's compute_type to rocblas_datatype_f16_r for fp16

parent 32ab028c
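The change is mechanical across the attention kernels below: the float scalars (alpha, beta, scale) are commented out and re-declared as half, and the compute_type argument of every rocblas_gemm_ex / rocblas_gemm_strided_batched_ex call is switched from rocblas_datatype_f32_r to rocblas_datatype_f16_r. rocBLAS expects alpha and beta to be passed in the compute type, which is why the scalar types change together with compute_type; with an f16 compute type the GEMM also accumulates in half precision rather than the previous fp32. For reference, a minimal sketch of the resulting all-fp16 call shape is shown here; the wrapper name and matrix handles (gemm_fp16_sketch, d_A, d_B, d_C) are illustrative and not taken from this repository.

#include <hip/hip_fp16.h>
#include <rocblas.h>

// Sketch: C = alpha * A * B + beta * C with inputs, output, scalars and the
// compute type all in fp16. Column-major storage, no transposes.
rocblas_status gemm_fp16_sketch(rocblas_handle handle,
                                const half* d_A, const half* d_B, half* d_C,
                                int m, int n, int k) {
  const half alpha = __float2half(1.0f);  // scalars must match compute_type
  const half beta  = __float2half(0.0f);
  return rocblas_gemm_ex(handle,
                         rocblas_operation_none, rocblas_operation_none,
                         m, n, k,
                         &alpha,
                         d_A, rocblas_datatype_f16_r, m,   // a_type, lda
                         d_B, rocblas_datatype_f16_r, k,   // b_type, ldb
                         &beta,
                         d_C, rocblas_datatype_f16_r, m,   // c_type, ldc
                         d_C, rocblas_datatype_f16_r, m,   // d_type, ldd (in-place D)
                         rocblas_datatype_f16_r,           // compute_type: was f32_r before this commit
                         rocblas_gemm_algo_standard,
                         0 /*solution_index*/, 0 /*flags*/);
}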
@@ -42,9 +42,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -110,7 +113,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  q_lin_results_ptr,
  rocblas_datatype_f16_r,
  output_lin_q_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -136,7 +139,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  k_lin_results_ptr,
  rocblas_datatype_f16_r,
  output_lin_kv_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -239,7 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  static_cast<void*>(outputs.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -278,9 +281,12 @@ std::vector<torch::Tensor> bwd_cuda(
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -352,7 +358,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_lin_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -378,7 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -513,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_q_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -539,7 +545,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_q_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -565,7 +571,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_kv_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -591,7 +597,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_kv_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
...
@@ -51,9 +51,12 @@ std::vector<torch::Tensor> fwd_cuda(
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -137,7 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
  q_lin_results_ptr,
  rocblas_datatype_f16_r /*d_type*/,
  output_lin_q_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -163,7 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
  k_lin_results_ptr,
  rocblas_datatype_f16_r /*d_type*/,
  output_lin_kv_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -266,7 +269,7 @@ std::vector<torch::Tensor> fwd_cuda(
  static_cast<void*>(output_lin_results.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -330,9 +333,12 @@ std::vector<torch::Tensor> bwd_cuda(
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -416,7 +422,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_lin_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -442,7 +448,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_weight_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -578,7 +584,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_lin_q_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -604,7 +610,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_q_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -630,7 +636,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_kv_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -656,7 +662,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_kv_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
...
@@ -37,10 +37,14 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta_zero = 0.0;
- const float beta_one = 1.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta_zero = 0.0;
+ // const float beta_one = 1.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta_zero = 0.0;
+ const half beta_one = 1.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -106,7 +110,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  q_lin_results_ptr,
  rocblas_datatype_f16_r,
  output_lin_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -203,7 +207,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  static_cast<void*>(outputs.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -231,9 +235,12 @@ std::vector<torch::Tensor> bwd_cuda(
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -301,7 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_lin_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -327,7 +334,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -461,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -487,7 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
...
@@ -36,10 +36,14 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta_zero = 0.0;
- const float beta_one = 1.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta_zero = 0.0;
+ // const float beta_one = 1.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta_zero = 0.0;
+ const half beta_one = 1.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -104,7 +108,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
  q_lin_results_ptr,
  rocblas_datatype_f16_r,
  output_lin_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -209,7 +213,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
  static_cast<void*>(outputs.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -237,9 +241,12 @@ std::vector<torch::Tensor> bwd_cuda(
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -307,7 +314,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_lin_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -333,7 +340,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -461,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -487,7 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
...
@@ -36,9 +36,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -102,7 +105,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  q_lin_results_ptr,
  rocblas_datatype_f16_r,
  output_lin_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -205,7 +208,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  static_cast<void*>(outputs.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -233,9 +236,12 @@ std::vector<torch::Tensor> bwd_cuda(
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -303,7 +309,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_lin_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -329,7 +335,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -464,7 +470,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -490,7 +496,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_grads.data_ptr()),
  rocblas_datatype_f16_r,
  embed_dim,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
...
@@ -40,9 +40,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -124,7 +127,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  q_lin_results_ptr,
  rocblas_datatype_f16_r /*d_type*/,
  output_lin_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -228,7 +231,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  static_cast<void*>(output_lin_results.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -280,9 +283,12 @@ std::vector<torch::Tensor> bwd_cuda(
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- const float alpha = 1.0;
- const float beta = 0.0;
- const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ // const float alpha = 1.0;
+ // const float beta = 0.0;
+ // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half alpha = 1.0;
+ const half beta = 0.0;
+ const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -361,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_lin_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -387,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(output_weight_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -523,7 +529,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_lin_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
@@ -550,7 +556,7 @@ std::vector<torch::Tensor> bwd_cuda(
  static_cast<void*>(input_weight_grads.data_ptr()),
  rocblas_datatype_f16_r /*d_type*/,
  embed_dim,
- rocblas_datatype_f32_r /*compute_type*/,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
  rocblas_gemm_algo_standard /*algo*/,
  0 /*solution_index*/,
  flags));
...
@@ -42,9 +42,48 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
  }
  }
+ // void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
+ //     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+ //     float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
+ //   cublasOperation_t opa = convertTransToCublasOperation(transa);
+ //   cublasOperation_t opb = convertTransToCublasOperation(transb);
+ //   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+ //   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ //   cublasSetStream(handle, stream);
+ //   float fAlpha = alpha;
+ //   float fBeta = beta;
+ //   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+ //   TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
+ //       opa, opb, (int)m, (int)n, (int)k,
+ //       (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+ //       b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
+ //       (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+ //       d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
+ //       (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
+ // }
+ // void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
+ //     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+ //     float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
+ //   auto stream = c10::cuda::getCurrentCUDAStream();
+ //   if ( (transa == 't') && (transb == 'n') ) {
+ //     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+ //     else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+ //   } else if ( (transa == 'n') && (transb == 'n') ) {
+ //     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+ //     else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+ //   } else if ( (transa == 'n') && (transb == 't') ) {
+ //     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+ //     else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+ //   } else {
+ //     AT_ASSERTM(false, "TransA and TransB are invalid");
+ //   }
+ // }
  void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
- float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
- float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
+ half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+ half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
@@ -56,16 +95,16 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
  opa, opb, (int)m, (int)n, (int)k,
- (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+ (void*)&alpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
  b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
- (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+ (void*)&beta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
  d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
- (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
+ (int)batchCount, rocblas_datatype_f16_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
  }
  void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
- float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
- float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
+ half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+ half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  if ( (transa == 't') && (transb == 'n') ) {
  if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
...
@@ -151,6 +151,7 @@ class EncdecMultiheadAttn(nn.Module):
      self.dropout,
  )
  if is_training:
+     print('default:', outputs)
      outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
  else:
      outputs = outputs + query
...
@@ -164,6 +164,8 @@ cublasStatus_t gemm_bias(
  at::Half* C,
  int ldc) {
  #ifdef __HIP_PLATFORM_HCC__
+ half hAlpha = __float2half(*alpha);
+ half hBeta = __float2half(*beta);
  return rocblas_gemm_ex(
  handle,
  transa,
@@ -171,21 +173,21 @@ cublasStatus_t gemm_bias(
  m,
  n,
  k,
- alpha,
+ /* alpha */ &hAlpha,
  A,
  rocblas_datatype_f16_r,
  lda,
  B,
  rocblas_datatype_f16_r,
  ldb,
- beta,
+ /* beta */ &hBeta,
  C,
  rocblas_datatype_f16_r,
  ldc,
  C,
  rocblas_datatype_f16_r,
  ldc,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard,
  0,
  0);
...
@@ -211,6 +211,8 @@ cublasStatus_t mlp_gemm(
  int ldc,
  int flag) {
  #ifdef __HIP_PLATFORM_HCC__
+ half hAlpha = __float2half(*alpha);
+ half hBeta = __float2half(*beta);
  return rocblas_gemm_ex(
  handle,
  transa,
@@ -218,21 +220,21 @@ cublasStatus_t mlp_gemm(
  m,
  n,
  k,
- alpha,
+ /* alpha */ &hAlpha,
  A,
  rocblas_datatype_f16_r,
  lda,
  B,
  rocblas_datatype_f16_r,
  ldb,
- beta,
+ /* beta */ &hBeta,
  C,
  rocblas_datatype_f16_r,
  ldc,
  C,
  rocblas_datatype_f16_r,
  ldc,
- rocblas_datatype_f32_r,
+ /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
  rocblas_gemm_algo_standard,
  0,
  flag);
...
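In gemm_bias and mlp_gemm the public signature still takes float* alpha/beta, so the commit narrows the scalars to half on the host with __float2half before the call, instead of changing the wrapper's interface. A condensed sketch of that pattern follows; the wrapper name and argument list (fp16_gemm_from_float_scalars) are assumptions for illustration, not the repository's actual function.

#include <hip/hip_fp16.h>
#include <rocblas.h>

// Sketch: the caller hands in float scalars, but the GEMM runs with an fp16
// compute type, so alpha/beta are converted to half before rocblas_gemm_ex.
rocblas_status fp16_gemm_from_float_scalars(rocblas_handle handle,
                                            rocblas_operation transa, rocblas_operation transb,
                                            int m, int n, int k,
                                            const float* alpha, const float* beta,
                                            const half* A, int lda,
                                            const half* B, int ldb,
                                            half* C, int ldc) {
  half hAlpha = __float2half(*alpha);  // narrow once on the host
  half hBeta  = __float2half(*beta);
  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                         &hAlpha,
                         A, rocblas_datatype_f16_r, lda,
                         B, rocblas_datatype_f16_r, ldb,
                         &hBeta,
                         C, rocblas_datatype_f16_r, ldc,
                         C, rocblas_datatype_f16_r, ldc,  // D aliases C, as in gemm_bias
                         rocblas_datatype_f16_r,          // compute_type
                         rocblas_gemm_algo_standard, 0, 0);
}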