OpenDAS / apex · Commits · 8bdbb502

Commit 8bdbb502, authored Oct 28, 2021 by hubertlu-tw
Parent: ba0e5fa5

Hipify encdec_multihead_attn

Showing 2 changed files with 427 additions and 67 deletions (+427 -67):

  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp    +4    -4
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu    +423  -63
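
Note on the pattern repeated throughout this commit: cublasGemmEx updates a single C matrix in place, while rocblas_gemm_ex takes an input matrix C and a separate output matrix D, plus an (algo, solution_index, flags) triple in place of cuBLAS's single algorithm enum. That is why every hipified call below passes the result buffer, its datatype, and its leading dimension twice. A minimal sketch of the mapping, assuming the public rocBLAS C API; the helper name is illustrative and not part of this commit:

// Sketch only: a cublasGemmEx-style fp16 GEMM with fp32 accumulation,
// expressed with rocblas_gemm_ex (assumes the public rocBLAS API in rocblas.h).
#include <rocblas.h>

rocblas_status gemm_ex_fp16_fp32accum(rocblas_handle handle,
                                      rocblas_operation op_a, rocblas_operation op_b,
                                      rocblas_int m, rocblas_int n, rocblas_int k,
                                      const float* alpha,
                                      const void* a, rocblas_int lda,
                                      const void* b, rocblas_int ldb,
                                      const float* beta,
                                      void* c, rocblas_int ldc) {
  // cublasGemmEx reads and writes C in place; rocblas_gemm_ex splits that into
  // an input C and an output D, so the same buffer is passed for both.
  return rocblas_gemm_ex(handle, op_a, op_b, m, n, k,
                         alpha,
                         a, rocblas_datatype_f16_r, lda,
                         b, rocblas_datatype_f16_r, ldb,
                         beta,
                         c, rocblas_datatype_f16_r, ldc,  // C: input
                         c, rocblas_datatype_f16_r, ldc,  // D: output (same buffer)
                         rocblas_datatype_f32_r,          // fp32 accumulation
                         rocblas_gemm_algo_standard,      // algo
                         0,                               // solution_index
                         0);                              // flags
}

The CUBLAS_OP_T / CUBLAS_OP_N arguments are left untouched in the diff; presumably compatibility defines elsewhere in the hipified tree map them to rocblas_operation values.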
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {

 std::vector<torch::Tensor> fwd_cuda(
                                bool use_time_mask,
@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
   );
 }

-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace encdec
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
+  m.def("forward", &multihead_attn::encdec::rocblas_gemm_ex::fwd, "Encdec Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::encdec::rocblas_gemm_ex::bwd, "Encdec Multihead Attention Backward.");
 }
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

@@ -85,10 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_q_dim,
@@ -96,20 +98,45 @@ std::vector<torch::Tensor> fwd_cuda(
                              embed_dim,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(input_weights_q.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(inputs_q.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(&beta),
                              q_lin_results_ptr,
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              output_lin_q_dim,
-                             CUDA_R_32F,
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             q_lin_results_ptr,
+                             rocblas_datatype_f16_r,
+                             output_lin_q_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_T,
+  //                            CUBLAS_OP_N,
+  //                            output_lin_q_dim,
+  //                            batches_q,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(input_weights_q.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(inputs_q.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            q_lin_results_ptr,
+  //                            CUDA_R_16F,
+  //                            output_lin_q_dim,
+  //                            CUDA_R_32F,
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              output_lin_kv_dim,
@@ -117,19 +144,44 @@ std::vector<torch::Tensor> fwd_cuda(
                              embed_dim,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(input_weights_kv.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(inputs_kv.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(&beta),
                              k_lin_results_ptr,
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              output_lin_kv_dim,
-                             CUDA_R_32F,
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             k_lin_results_ptr,
+                             rocblas_datatype_f16_r,
+                             output_lin_kv_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_T,
+  //                            CUBLAS_OP_N,
+  //                            output_lin_kv_dim,
+  //                            batches_kv,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(input_weights_kv.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(inputs_kv.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            k_lin_results_ptr,
+  //                            CUDA_R_16F,
+  //                            output_lin_kv_dim,
+  //                            CUDA_R_32F,
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
+  // TODO (OK)
   gemm_switch_fp32accum(     state,
                              a_layout_t,
                              b_layout_n,
@@ -146,8 +198,29 @@ std::vector<torch::Tensor> fwd_cuda(
                              beta,
                              static_cast<half*>(softmax_results_ptr),
                              k_seq_len,
                              k_seq_len*q_seq_len,
+                             static_cast<half*>(softmax_results_ptr),
+                             k_seq_len,
+                             k_seq_len*q_seq_len,
                              attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_t,
+  //                        b_layout_n,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        head_dim,
+  //                        scale,
+  //                        static_cast<const half*>(k_lin_results_ptr),
+  //                        lead_dim_kv,
+  //                        batch_stride_kv,
+  //                        static_cast<const half*>(q_lin_results_ptr),
+  //                        lead_dim_q,
+  //                        batch_stride_q,
+  //                        beta,
+  //                        static_cast<half*>(softmax_results_ptr),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        attn_batches);
   // Padded Softmax
   bool softmax_success = false;
@@ -191,6 +264,7 @@ std::vector<torch::Tensor> fwd_cuda(
   }
   // Matmul2
+  // TODO (OK)
   gemm_switch_fp32accum(     state,
                              a_layout_n,
                              b_layout_n,
@@ -208,10 +282,32 @@ std::vector<torch::Tensor> fwd_cuda(
                              static_cast<half*>(matmul2_results.data_ptr()),
                              head_dim*attn_batches,
                              head_dim,
+                             static_cast<half*>(matmul2_results.data_ptr()),
+                             head_dim*attn_batches,
+                             head_dim,
                              attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_n,
+  //                        head_dim,
+  //                        q_seq_len,
+  //                        k_seq_len,
+  //                        alpha,
+  //                        static_cast<const half*>(v_lin_results_ptr),
+  //                        lead_dim_kv,
+  //                        batch_stride_kv,
+  //                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        static_cast<half*>(matmul2_results.data_ptr()),
+  //                        head_dim*attn_batches,
+  //                        head_dim,
+  //                        attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -219,20 +315,45 @@ std::vector<torch::Tensor> fwd_cuda(
                              embed_dim,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(output_weights.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(matmul2_results.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(outputs.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(outputs.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_T,
+  //                            CUBLAS_OP_N,
+  //                            embed_dim,
+  //                            batches_q,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(output_weights.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(matmul2_results.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(outputs.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            //CUBLAS_GEMM_ALGO1_TENSOR_OP));
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

   return {
     input_lin_q_results,
@@ -311,11 +432,12 @@ std::vector<torch::Tensor> bwd_cuda(
   char a_layout_t{'t'};
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -323,20 +445,45 @@ std::vector<torch::Tensor> bwd_cuda(
                              embed_dim,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(output_weights.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(output_grads.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(output_lin_grads.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(output_lin_grads.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_N,
+  //                            CUBLAS_OP_N,
+  //                            embed_dim,
+  //                            batches_q,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(output_weights.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(output_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(output_lin_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -344,19 +491,44 @@ std::vector<torch::Tensor> bwd_cuda(
                              batches_q,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(matmul2_results.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(output_grads.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(output_weight_grads.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(output_weight_grads.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_N,
+  //                            CUBLAS_OP_T,
+  //                            embed_dim,
+  //                            embed_dim,
+  //                            batches_q,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(matmul2_results.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(output_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(output_weight_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // MatMul2 Dgrad1
+  // TODO (OK)
   gemm_switch_fp32accum(     state,
                              a_layout_t,
                              b_layout_n,
@@ -374,9 +546,31 @@ std::vector<torch::Tensor> bwd_cuda(
                              static_cast<half*>(matmul2_grads.data_ptr()),
                              k_seq_len,
                              k_seq_len*q_seq_len,
+                             static_cast<half*>(matmul2_grads.data_ptr()),
+                             k_seq_len,
+                             k_seq_len*q_seq_len,
                              attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_t,
+  //                        b_layout_n,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        head_dim,
+  //                        alpha,
+  //                        static_cast<const half*>(v_lin_results_ptr),
+  //                        lead_dim_kv,
+  //                        batch_stride_kv,
+  //                        static_cast<const half*>(output_lin_grads.data_ptr()),
+  //                        head_dim*attn_batches,
+  //                        head_dim,
+  //                        beta,
+  //                        static_cast<half*>(matmul2_grads.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        attn_batches);
   // Matmul2 Dgrad2
+  // TODO (OK)
   gemm_switch_fp32accum(     state,
                              a_layout_n,
                              b_layout_t,
@@ -394,7 +588,28 @@ std::vector<torch::Tensor> bwd_cuda(
                              v_lin_grads_ptr,
                              lead_dim_kv,
                              batch_stride_kv,
+                             v_lin_grads_ptr,
+                             lead_dim_kv,
+                             batch_stride_kv,
                              attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_t,
+  //                        head_dim,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        alpha,
+  //                        static_cast<const half*>(output_lin_grads.data_ptr()),
+  //                        head_dim*attn_batches,
+  //                        head_dim,
+  //                        static_cast<const half*>(dropout_results.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        v_lin_grads_ptr,
+  //                        lead_dim_kv,
+  //                        batch_stride_kv,
+  //                        attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -416,6 +631,7 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
   // Matmul1 Dgrad1
+  // TODO (OK)
   gemm_switch_fp32accum(     state,
                              a_layout_n,
                              b_layout_n,
@@ -433,9 +649,31 @@ std::vector<torch::Tensor> bwd_cuda(
                              q_lin_grads_ptr,
                              lead_dim_q,
                              batch_stride_q,
+                             q_lin_grads_ptr,
+                             lead_dim_q,
+                             batch_stride_q,
                              attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_n,
+  //                        head_dim,
+  //                        q_seq_len,
+  //                        k_seq_len,
+  //                        scale,
+  //                        k_lin_results_ptr,
+  //                        lead_dim_kv,
+  //                        batch_stride_kv,
+  //                        static_cast<half*>(matmul2_grads.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        q_lin_grads_ptr,
+  //                        lead_dim_q,
+  //                        batch_stride_q,
+  //                        attn_batches);
   // Matmul1 Dgrad2
+  // TODO (OK)
   gemm_switch_fp32accum(     state,
                              a_layout_n,
                              b_layout_t,
@@ -453,10 +691,32 @@ std::vector<torch::Tensor> bwd_cuda(
                              k_lin_grads_ptr,
                              lead_dim_kv,
                              batch_stride_kv,
+                             k_lin_grads_ptr,
+                             lead_dim_kv,
+                             batch_stride_kv,
                              attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_t,
+  //                        head_dim,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        scale,
+  //                        q_lin_results_ptr,
+  //                        lead_dim_q,
+  //                        batch_stride_q,
+  //                        static_cast<half*>(matmul2_grads.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        k_lin_grads_ptr,
+  //                        lead_dim_kv,
+  //                        batch_stride_kv,
+  //                        attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -464,21 +724,46 @@ std::vector<torch::Tensor> bwd_cuda(
                              output_lin_q_dim,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(input_weights_q.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(q_lin_grads_ptr),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              output_lin_q_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(input_q_grads.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(input_q_grads.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_N,
+  //                            CUBLAS_OP_N,
+  //                            embed_dim,
+  //                            batches_q,
+  //                            output_lin_q_dim,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(input_weights_q.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(q_lin_grads_ptr),
+  //                            CUDA_R_16F,
+  //                            output_lin_q_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(input_q_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            //CUBLAS_GEMM_ALGO10_TENSOR_OP));
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -486,20 +771,45 @@ std::vector<torch::Tensor> bwd_cuda(
                              batches_q,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(inputs_q.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(q_lin_grads_ptr),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              output_lin_q_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(input_weight_q_grads.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(input_weight_q_grads.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_N,
+  //                            CUBLAS_OP_T,
+  //                            embed_dim,
+  //                            output_lin_q_dim,
+  //                            batches_q,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(inputs_q.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(q_lin_grads_ptr),
+  //                            CUDA_R_16F,
+  //                            output_lin_q_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(input_weight_q_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              embed_dim,
@@ -507,21 +817,46 @@ std::vector<torch::Tensor> bwd_cuda(
                              output_lin_kv_dim,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(input_weights_kv.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(k_lin_grads_ptr),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              output_lin_kv_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(input_kv_grads.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(input_kv_grads.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_N,
+  //                            CUBLAS_OP_N,
+  //                            embed_dim,
+  //                            batches_kv,
+  //                            output_lin_kv_dim,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(input_weights_kv.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(k_lin_grads_ptr),
+  //                            CUDA_R_16F,
+  //                            output_lin_kv_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(input_kv_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            //CUBLAS_GEMM_ALGO10_TENSOR_OP));
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_T,
                              embed_dim,
@@ -529,18 +864,43 @@ std::vector<torch::Tensor> bwd_cuda(
                              batches_kv,
                              static_cast<const void*>(&alpha),
                              static_cast<const void*>(inputs_kv.data_ptr()),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              embed_dim,
                              static_cast<const void*>(k_lin_grads_ptr),
-                             CUDA_R_16F,
+                             rocblas_datatype_f16_r,
                              output_lin_kv_dim,
                              static_cast<const void*>(&beta),
                              static_cast<void*>(input_weight_kv_grads.data_ptr()),
-                             CUDA_R_16F,
-                             embed_dim,
-                             CUDA_R_32F,
-                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             static_cast<void*>(input_weight_kv_grads.data_ptr()),
+                             rocblas_datatype_f16_r,
+                             embed_dim,
+                             rocblas_datatype_f32_r,
+                             algo,
+                             solution_index,
+                             flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //                            CUBLAS_OP_N,
+  //                            CUBLAS_OP_T,
+  //                            embed_dim,
+  //                            output_lin_kv_dim,
+  //                            batches_kv,
+  //                            static_cast<const void*>(&alpha),
+  //                            static_cast<const void*>(inputs_kv.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            static_cast<const void*>(k_lin_grads_ptr),
+  //                            CUDA_R_16F,
+  //                            output_lin_kv_dim,
+  //                            static_cast<const void*>(&beta),
+  //                            static_cast<void*>(input_weight_kv_grads.data_ptr()),
+  //                            CUDA_R_16F,
+  //                            embed_dim,
+  //                            CUDA_R_32F,
+  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

   return {
     input_q_grads,
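
The hipified calls above read algo, solution_index, and flags without defining them in the visible hunks; they presumably come from earlier, unshown lines of the file. A plausible minimal definition, assuming standard rocBLAS types (a sketch, not the commit's actual code):

// Assumed defaults for the rocblas_gemm_ex control arguments; illustrative only.
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;  // default kernel selection
int32_t solution_index = 0;   // ignored when algo is rocblas_gemm_algo_standard
uint32_t flags = 0;           // no special behavior flags requested

With rocblas_gemm_algo_standard the solution index is ignored, which mirrors the heuristic kernel selection of the CUBLAS_GEMM_DEFAULT_TENSOR_OP argument these calls replaced.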