Commit 227be6be authored by root, committed by flyingdown

revert multihead_attn to fp32_r

parent 412a8ac5
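In short, this revert takes the ROCm multihead_attn extensions back to fp32 accumulation: the half-precision h_alpha/h_beta/h_scale constants, the half-typed RocblasStridedBatchedGemm and gemm_switch_fp32accum overloads, and the APEX_APEX_ROCBLAS_GEMM_ALLOW_HALF switch are deleted, and every GEMM compute type goes back from rocblas_datatype_f16_r to rocblas_datatype_f32_r while the A/B/C/D matrices stay in fp16. The following is only a minimal, illustrative sketch of that retained call shape, not code from the patch; the sizes, pointer names, and leading dimensions are arbitrary example values.

// Illustrative sketch (not from this patch): one fp16-storage GEMM run through
// rocblas_gemm_ex with fp32 accumulation, i.e. f16_r A/B/C/D types, f32_r
// compute type, and plain float alpha/beta.
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>  // some ROCm versions install this header as <rocblas.h>
#include <cstdio>
#include <vector>

int main() {
  const int m = 64, n = 64, k = 64;
  const float alpha = 1.0f;  // fp32 scaling factors, matching the fp32 compute type
  const float beta  = 0.0f;

  // fp16 host data: A and B filled with ones, so every entry of D should equal k.
  std::vector<__half> ha(m * k, __float2half(1.0f));
  std::vector<__half> hb(k * n, __float2half(1.0f));
  std::vector<__half> hd(m * n);

  __half *da, *db, *dc, *dd;
  hipMalloc(&da, ha.size() * sizeof(__half));
  hipMalloc(&db, hb.size() * sizeof(__half));
  hipMalloc(&dc, hd.size() * sizeof(__half));
  hipMalloc(&dd, hd.size() * sizeof(__half));
  hipMemcpy(da, ha.data(), ha.size() * sizeof(__half), hipMemcpyHostToDevice);
  hipMemcpy(db, hb.data(), hb.size() * sizeof(__half), hipMemcpyHostToDevice);
  hipMemset(dc, 0, hd.size() * sizeof(__half));

  rocblas_handle handle;
  rocblas_create_handle(&handle);

  // fp16 inputs/outputs, fp32 accumulation: the "fp32_r" path this commit restores.
  rocblas_gemm_ex(handle,
                  rocblas_operation_none, rocblas_operation_none,
                  m, n, k,
                  &alpha,
                  da, rocblas_datatype_f16_r, m,   // A, lda
                  db, rocblas_datatype_f16_r, k,   // B, ldb
                  &beta,
                  dc, rocblas_datatype_f16_r, m,   // C, ldc
                  dd, rocblas_datatype_f16_r, m,   // D, ldd
                  rocblas_datatype_f32_r,          // compute type reverted to fp32
                  rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);

  hipMemcpy(hd.data(), dd, hd.size() * sizeof(__half), hipMemcpyDeviceToHost);
  printf("D[0] = %f (expected %d)\n", __half2float(hd[0]), k);

  rocblas_destroy_handle(handle);
  hipFree(da); hipFree(db); hipFree(dc); hipFree(dd);
  return 0;
}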
@@ -42,13 +42,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 const int batch_stride_q = head_dim;
 const int batch_stride_kv = 2 * head_dim;
 const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // There is no reason to use more than one stream as every kernel is
 // sequentially dependent
@@ -285,9 +281,6 @@ std::vector<torch::Tensor> bwd_cuda(
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // TODO: Streams can be used in Backprop but I haven't added more than one
 // in my first attempt to create the code
@@ -390,176 +383,51 @@ std::vector<torch::Tensor> bwd_cuda(
 0 /*solution_index*/,
 flags));
-// Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
-embed_dim,
-embed_dim,
-batches_q,
-static_cast<const void*>(&h_alpha),
-static_cast<const void*>(matmul2_results.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<const void*>(output_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<const void*>(&h_beta),
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
-rocblas_gemm_algo_standard /*algo*/,
-0 /*solution_index*/,
-flags)));
-// MatMul2 Dgrad1
-gemm_switch_fp32accum( a_layout_t,
-b_layout_n,
-k_seq_len,
-q_seq_len,
-head_dim,
-h_alpha,
-static_cast<const half*>(v_lin_results_ptr),
-lead_dim_kv,
-batch_stride_kv,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-h_beta,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-attn_batches,
-flags);
-// Matmul2 Dgrad2
-gemm_switch_fp32accum( a_layout_n,
-b_layout_t,
-head_dim,
-k_seq_len,
-q_seq_len,
-h_alpha,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-static_cast<const half*>(dropout_results.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-h_beta,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-attn_batches,
-flags);
-} else {
-// Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
-embed_dim,
-batches_q,
-embed_dim,
-static_cast<const void*>(&alpha),
-static_cast<const void*>(output_weights.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<const void*>(output_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<const void*>(&beta),
-static_cast<void*>(output_lin_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<void*>(output_lin_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-rocblas_datatype_f32_r,
-rocblas_gemm_algo_standard /*algo*/,
-0 /*solution_index*/,
-flags)));
-// Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
-embed_dim,
-embed_dim,
-batches_q,
-static_cast<const void*>(&alpha),
-static_cast<const void*>(matmul2_results.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<const void*>(output_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<const void*>(&beta),
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r,
-embed_dim,
-rocblas_datatype_f32_r,
-rocblas_gemm_algo_standard /*algo*/,
-0 /*solution_index*/,
-flags)));
-// MatMul2 Dgrad1
-gemm_switch_fp32accum( a_layout_t,
-b_layout_n,
-k_seq_len,
-q_seq_len,
-head_dim,
-alpha,
-static_cast<const half*>(v_lin_results_ptr),
-lead_dim_kv,
-batch_stride_kv,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-beta,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-attn_batches,
-flags);
-// Matmul2 Dgrad2
-gemm_switch_fp32accum( a_layout_n,
-b_layout_t,
-head_dim,
-k_seq_len,
-q_seq_len,
-alpha,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-static_cast<const half*>(dropout_results.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-beta,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-attn_batches,
-flags);
-}
+// MatMul2 Dgrad1
+gemm_switch_fp32accum( a_layout_t,
+b_layout_n,
+k_seq_len,
+q_seq_len,
+head_dim,
+alpha,
+static_cast<const half*>(v_lin_results_ptr),
+lead_dim_kv,
+batch_stride_kv,
+static_cast<const half*>(output_lin_grads.data_ptr()),
+head_dim*attn_batches,
+head_dim,
+beta,
+static_cast<half*>(matmul2_grads.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
+static_cast<half*>(matmul2_grads.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
+attn_batches,
+flags);
+// Matmul2 Dgrad2
+gemm_switch_fp32accum( a_layout_n,
+b_layout_t,
+head_dim,
+k_seq_len,
+q_seq_len,
+alpha,
+static_cast<const half*>(output_lin_grads.data_ptr()),
+head_dim*attn_batches,
+head_dim,
+static_cast<const half*>(dropout_results.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
+beta,
+v_lin_grads_ptr,
+lead_dim_kv,
+batch_stride_kv,
+v_lin_grads_ptr,
+lead_dim_kv,
+batch_stride_kv,
+attn_batches,
+flags);
 // Apply Dropout Mask and Scale by Dropout Probability
 apex_masked_scale_cuda<at::Half,float,uint32_t>(
......
@@ -51,13 +51,9 @@ std::vector<torch::Tensor> fwd_cuda(
 const int batch_stride_q = head_dim;
 const int batch_stride_kv = 2 * head_dim;
 const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // There is no reason to use more than one stream as every kernel is
 // sequentially dependent
@@ -337,9 +333,6 @@ std::vector<torch::Tensor> bwd_cuda(
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // TODO: Streams can be used in Backprop but I haven't added more than one
 // in my first attempt to create the code
@@ -454,177 +447,51 @@ std::vector<torch::Tensor> bwd_cuda(
 0 /*solution_index*/,
 flags));
-// Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
-embed_dim,
-embed_dim,
-batches_q,
-static_cast<const void*>(&h_alpha),
-static_cast<const void*>(matmul2_results.data_ptr()),
-rocblas_datatype_f16_r /*a_type*/,
-embed_dim,
-static_cast<const void*>(dropout_add_grads.data_ptr()),
-rocblas_datatype_f16_r /*b_type*/,
-embed_dim,
-static_cast<const void*>(&h_beta),
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r /*c_type*/,
-embed_dim,
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r /*d_type*/,
-embed_dim,
-/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
-rocblas_gemm_algo_standard /*algo*/,
-0 /*solution_index*/,
-flags)));
-// MatMul2 Dgrad1
-gemm_switch_fp32accum( a_layout_t,
-b_layout_n,
-k_seq_len,
-q_seq_len,
-head_dim,
-h_alpha,
-static_cast<const half*>(v_lin_results_ptr),
-lead_dim_kv,
-batch_stride_kv,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-h_beta,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-attn_batches,
-flags);
-// Matmul2 Dgrad2
-gemm_switch_fp32accum( a_layout_n,
-b_layout_t,
-head_dim,
-k_seq_len,
-q_seq_len,
-h_alpha,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-static_cast<const half*>(dropout_results.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-h_beta,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-attn_batches,
-flags);
-} else {
-// Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
-embed_dim,
-batches_q,
-embed_dim,
-static_cast<const void*>(&alpha),
-static_cast<const void*>(output_weights.data_ptr()),
-rocblas_datatype_f16_r /*a_type*/,
-embed_dim,
-static_cast<const void*>(dropout_add_grads.data_ptr()),
-rocblas_datatype_f16_r /*b_type*/,
-embed_dim,
-static_cast<const void*>(&beta),
-static_cast<void*>(output_lin_grads.data_ptr()),
-rocblas_datatype_f16_r /*c_type*/,
-embed_dim,
-static_cast<void*>(output_lin_grads.data_ptr()),
-rocblas_datatype_f16_r /*d_type*/,
-embed_dim,
-rocblas_datatype_f32_r /*compute_type*/,
-rocblas_gemm_algo_standard /*algo*/,
-0 /*solution_index*/,
-flags)));
-// Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
-embed_dim,
-embed_dim,
-batches_q,
-static_cast<const void*>(&alpha),
-static_cast<const void*>(matmul2_results.data_ptr()),
-rocblas_datatype_f16_r /*a_type*/,
-embed_dim,
-static_cast<const void*>(dropout_add_grads.data_ptr()),
-rocblas_datatype_f16_r /*b_type*/,
-embed_dim,
-static_cast<const void*>(&beta),
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r /*c_type*/,
-embed_dim,
-static_cast<void*>(output_weight_grads.data_ptr()),
-rocblas_datatype_f16_r /*d_type*/,
-embed_dim,
-rocblas_datatype_f32_r /*compute_type*/,
-rocblas_gemm_algo_standard /*algo*/,
-0 /*solution_index*/,
-flags)));
-// MatMul2 Dgrad1
-gemm_switch_fp32accum( a_layout_t,
-b_layout_n,
-k_seq_len,
-q_seq_len,
-head_dim,
-alpha,
-static_cast<const half*>(v_lin_results_ptr),
-lead_dim_kv,
-batch_stride_kv,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-beta,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-static_cast<half*>(matmul2_grads.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-attn_batches,
-flags);
-// Matmul2 Dgrad2
-gemm_switch_fp32accum( a_layout_n,
-b_layout_t,
-head_dim,
-k_seq_len,
-q_seq_len,
-alpha,
-static_cast<const half*>(output_lin_grads.data_ptr()),
-head_dim*attn_batches,
-head_dim,
-static_cast<const half*>(dropout_results.data_ptr()),
-k_seq_len,
-k_seq_len*q_seq_len,
-beta,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-v_lin_grads_ptr,
-lead_dim_kv,
-batch_stride_kv,
-attn_batches,
-flags);
-}
+// MatMul2 Dgrad1
+gemm_switch_fp32accum( a_layout_t,
+b_layout_n,
+k_seq_len,
+q_seq_len,
+head_dim,
+alpha,
+static_cast<const half*>(v_lin_results_ptr),
+lead_dim_kv,
+batch_stride_kv,
+static_cast<const half*>(output_lin_grads.data_ptr()),
+head_dim*attn_batches,
+head_dim,
+beta,
+static_cast<half*>(matmul2_grads.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
+static_cast<half*>(matmul2_grads.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
+attn_batches,
+flags);
+// Matmul2 Dgrad2
+gemm_switch_fp32accum( a_layout_n,
+b_layout_t,
+head_dim,
+k_seq_len,
+q_seq_len,
+alpha,
+static_cast<const half*>(output_lin_grads.data_ptr()),
+head_dim*attn_batches,
+head_dim,
+static_cast<const half*>(dropout_results.data_ptr()),
+k_seq_len,
+k_seq_len*q_seq_len,
+beta,
+v_lin_grads_ptr,
+lead_dim_kv,
+batch_stride_kv,
+v_lin_grads_ptr,
+lead_dim_kv,
+batch_stride_kv,
+attn_batches,
+flags);
 // Apply Dropout Mask and Scale by Dropout Probability
 apex_masked_scale_cuda<at::Half,float,uint32_t>(
......
@@ -37,14 +37,10 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 const int lead_dim = attn_batches * 3 * head_dim;
 const int batch_stride = 3 * head_dim;
 const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
 const float alpha = 1.0;
 const float beta_zero = 0.0;
 const float beta_one = 1.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta_zero = 0.0;
-const half h_beta_one = 1.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // There is no reason to use more than one stream as every kernel is
 // sequentially dependent
@@ -238,9 +234,6 @@ std::vector<torch::Tensor> bwd_cuda(
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // TODO: Streams can be used in Backprop but I haven't added more than one
 // in my first attempt to create the code
......
@@ -40,10 +40,6 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 const float beta_zero = 0.0;
 const float beta_one = 1.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta_zero = 0.0;
-const half h_beta_one = 1.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // There is no reason to use more than one stream as every kernel is
 // sequentially dependent
@@ -244,9 +240,6 @@ std::vector<torch::Tensor> bwd_cuda(
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // TODO: Streams can be used in Backprop but I haven't added more than one
 // in my first attempt to create the code
......
@@ -36,12 +36,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 const int lead_dim = attn_batches * 3 * head_dim;
 const int batch_stride = 3 * head_dim;
 const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-// const float alpha = 1.0;
-// const float beta = 0.0;
-// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half alpha = 1.0;
-const half beta = 0.0;
-const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
+const float alpha = 1.0;
+const float beta = 0.0;
+const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
 // There is no reason to use more than one stream as every kernel is
 // sequentially dependent
@@ -105,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 q_lin_results_ptr,
 rocblas_datatype_f16_r,
 output_lin_dim,
-/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
 flags));
@@ -208,7 +205,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 static_cast<void*>(outputs.data_ptr()),
 rocblas_datatype_f16_r,
 embed_dim,
-/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
 flags));
@@ -239,9 +236,6 @@ std::vector<torch::Tensor> bwd_cuda(
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // TODO: Streams can be used in Backprop but I haven't added more than one
 // in my first attempt to create the code
......
@@ -43,9 +43,6 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // There is no reason to use more than one stream as every kernel is
 // sequentially dependent
@@ -286,9 +283,6 @@ std::vector<torch::Tensor> bwd_cuda(
 const float alpha = 1.0;
 const float beta = 0.0;
 const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-const half h_alpha = 1.0;
-const half h_beta = 0.0;
-const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
 // TODO: Streams can be used in Backprop but I haven't added more than one
 // in my first attempt to create the code
......
@@ -10,7 +10,6 @@
 //#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include "utils.h"
 //#include "cutlass/cutlass.h"
 //#include "cutlass/gemm/gemm.h"
@@ -29,8 +28,6 @@ int32_t solution_index = 0;
 rocblas_int flags = 0;
 */
-static bool use_fp16 = parseEnvVarFlag("APEX_APEX_ROCBLAS_GEMM_ALLOW_HALF");
 namespace {
 cublasOperation_t convertTransToCublasOperation(char trans) {
 if (trans == 't')
@@ -84,45 +81,6 @@ void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
 }
 }
-void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
-half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
-cublasOperation_t opa = convertTransToCublasOperation(transa);
-cublasOperation_t opb = convertTransToCublasOperation(transb);
-cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-cublasSetStream(handle, stream);
-float fAlpha = alpha;
-float fBeta = beta;
-//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
-opa, opb, (int)m, (int)n, (int)k,
-(void*)&alpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
-b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
-(void*)&beta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
-d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-(int)batchCount, rocblas_datatype_f16_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
-}
-void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
-half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
-auto stream = c10::cuda::getCurrentCUDAStream();
-if ( (transa == 't') && (transb == 'n') ) {
-if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-} else if ( (transa == 'n') && (transb == 'n') ) {
-if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-} else if ( (transa == 'n') && (transb == 't') ) {
-if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-} else {
-AT_ASSERTM(false, "TransA and TransB are invalid");
-}
-}
 void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
 int64_t *lda, int64_t *ldb, int64_t *ldc) {
 int transa_ = ((transa == 't') || (transa == 'T'));
......