OpenDAS / apex

Commit 83181423, authored Oct 28, 2021 by hubertlu-tw

Hipify self_multihead_attn

Enable HIP float to half conversion

parent 61416180
Showing 5 changed files with 346 additions and 51 deletions (+346 -51)
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                    +4   -1
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu   +4   -1
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                 +4   -1
apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp                      +4   -4
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                      +330 -44
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 ...
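The four added lines are the same in each .cu file below and must appear before any header that pulls in hip/hip_fp16.h, otherwise the half operators and float-to-half conversions stay compiled out. A minimal sketch of the effect, assuming a ROCm/HIP toolchain (the kernel and names below are illustrative, not part of this commit):

// Illustrative only: re-enable the __half operators and float <-> half conversions
// that hip_fp16.h compiles out when these macros are defined, then rely on them.
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
#include <hip/hip_fp16.h>

__global__ void scale_half(__half* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __half hs = s;     // float -> half conversion; rejected if __HIP_NO_HALF_CONVERSIONS__ is set
    x[i] = x[i] * hs;  // half * half operator; rejected if __HIP_NO_HALF_OPERATORS__ is set
  }
}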
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

 #include <vector>
 #include <math.h>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 ...
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 ...
apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp

 ...
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {
 std::vector<torch::Tensor> fwd_cuda(
                                bool use_time_mask,
 ...
@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd(
   );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self::cublas_gemmex::fwd,  "Self Multihead Attention Forward.");
+  m.def("forward",  &multihead_attn::self::rocblas_gemm_ex::fwd,  "Self Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
+  m.def("backward", &multihead_attn::self::rocblas_gemm_ex::bwd, "Self Multihead Attention Backward.");
 }
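Only the innermost C++ namespace changes here; the names exported to Python ("forward", "backward") are unchanged, so the Python wrapper that loads this extension does not need to be touched. A condensed, self-contained sketch of that binding pattern (the stub bodies below are placeholders for illustration, not the real fwd/bwd from this file):

#include <torch/extension.h>
#include <vector>

namespace multihead_attn { namespace self { namespace rocblas_gemm_ex {
// Placeholder stubs; the real implementations call into self_multihead_attn_cuda.cu.
std::vector<torch::Tensor> fwd() { return {}; }
std::vector<torch::Tensor> bwd() { return {}; }
}}} // namespace multihead_attn::self::rocblas_gemm_ex

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Exported names stay the same, so callers keep using module.forward / module.backward.
  m.def("forward",  &multihead_attn::self::rocblas_gemm_ex::fwd,  "Self Multihead Attention Forward.");
  m.def("backward", &multihead_attn::self::rocblas_gemm_ex::bwd,  "Self Multihead Attention Backward.");
}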
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu

 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 ...
@@ -77,10 +80,11 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_dim,
 ...
@@ -88,19 +92,44 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(inputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 q_lin_results_ptr,
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr,
+                rocblas_datatype_f16_r,
+                output_lin_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //              CUBLAS_OP_T,
+  //              CUBLAS_OP_N,
+  //              output_lin_dim,
+  //              batches,
+  //              embed_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(input_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(inputs.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta),
+  //              q_lin_results_ptr,
+  //              CUDA_R_16F,
+  //              output_lin_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
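Every cublasGemmEx call in this file follows the translation pattern shown above: CUDA_R_16F becomes rocblas_datatype_f16_r, the CUDA_R_32F compute type becomes rocblas_datatype_f32_r, and CUBLAS_GEMM_DEFAULT_TENSOR_OP is replaced by the algo / solution_index / flags triple. rocblas_gemm_ex also takes a separate output matrix D in addition to C, which is why each call gains an extra pointer/type/leading-dimension group pointing at the same buffer. A minimal sketch of that shape; the helper name gemm_ex_f16 is made up, and algo, solution_index and flags in the real calls are presumably defined in parts of the file not shown in this diff:

#include <rocblas.h>

// Hypothetical helper, for illustration only: fp16 GEMM with fp32 accumulation,
// passing the same buffer as both C (input) and D (output), as the hipified calls do.
inline rocblas_status gemm_ex_f16(rocblas_handle handle,
                                  rocblas_operation transA, rocblas_operation transB,
                                  rocblas_int m, rocblas_int n, rocblas_int k,
                                  const float* alpha,
                                  const void* A, rocblas_int lda,
                                  const void* B, rocblas_int ldb,
                                  const float* beta,
                                  void* C, rocblas_int ldc) {
  return rocblas_gemm_ex(handle, transA, transB, m, n, k,
                         alpha,
                         A, rocblas_datatype_f16_r, lda,
                         B, rocblas_datatype_f16_r, ldb,
                         beta,
                         C, rocblas_datatype_f16_r, ldc,  // C: input
                         C, rocblas_datatype_f16_r, ldc,  // D: output, same buffer
                         rocblas_datatype_f32_r,          // accumulate in fp32
                         rocblas_gemm_algo_standard,
                         /*solution_index=*/0, /*flags=*/0);
}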
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
+  // TODO (OK)
   gemm_switch_fp32accum( state,
                          a_layout_t,
                          b_layout_n,
 ...
@@ -118,7 +147,28 @@ std::vector<torch::Tensor> fwd_cuda(
                          static_cast<half*>(softmax_results_ptr),
                          k_seq_len,
                          k_seq_len*q_seq_len,
+                         static_cast<half*>(softmax_results_ptr),
+                         k_seq_len,
+                         k_seq_len*q_seq_len,
                          attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_t,
+  //                        b_layout_n,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        head_dim,
+  //                        scale,
+  //                        static_cast<const half*>(k_lin_results_ptr),
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        static_cast<const half*>(q_lin_results_ptr),
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        beta,
+  //                        static_cast<half*>(softmax_results_ptr),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        attn_batches);
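gemm_switch_fp32accum is apex's own strided-batched GEMM helper, so its exact signature is not visible in this diff; what is visible is that the HIP path passes the output pointer, leading dimension and batch stride a second time, mirroring the separate C/D matrices of rocBLAS. A hedged sketch of the rocBLAS call this pattern maps onto, for illustration only (this is not the body of gemm_switch_fp32accum, and the helper name is made up):

#include <rocblas.h>

// Illustrative only: batched fp16 GEMM with fp32 accumulation where the output D
// reuses the C buffer, leading dimension and stride -- the triple that the hipified
// gemm_switch_fp32accum calls above now pass twice.
inline rocblas_status batched_gemm_f16(rocblas_handle handle,
                                       rocblas_operation transA, rocblas_operation transB,
                                       rocblas_int m, rocblas_int n, rocblas_int k,
                                       const float* alpha,
                                       const void* A, rocblas_int lda, rocblas_stride strideA,
                                       const void* B, rocblas_int ldb, rocblas_stride strideB,
                                       const float* beta,
                                       void* C, rocblas_int ldc, rocblas_stride strideC,
                                       rocblas_int batch_count) {
  return rocblas_gemm_strided_batched_ex(handle, transA, transB, m, n, k,
                                         alpha,
                                         A, rocblas_datatype_f16_r, lda, strideA,
                                         B, rocblas_datatype_f16_r, ldb, strideB,
                                         beta,
                                         C, rocblas_datatype_f16_r, ldc, strideC,  // C
                                         C, rocblas_datatype_f16_r, ldc, strideC,  // D = C
                                         batch_count,
                                         rocblas_datatype_f32_r,
                                         rocblas_gemm_algo_standard,
                                         /*solution_index=*/0, /*flags=*/0);
}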
   // Padded Softmax
   bool softmax_success = false;
 ...
@@ -162,6 +212,7 @@ std::vector<torch::Tensor> fwd_cuda(
   }
   // Matmul2
+  // TODO (OK)
   gemm_switch_fp32accum( state,
                          a_layout_n,
                          b_layout_n,
 ...
@@ -179,10 +230,32 @@ std::vector<torch::Tensor> fwd_cuda(
                          static_cast<half*>(matmul2_results.data_ptr()),
                          head_dim*attn_batches,
                          head_dim,
+                         static_cast<half*>(matmul2_results.data_ptr()),
+                         head_dim*attn_batches,
+                         head_dim,
                          attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_n,
+  //                        head_dim,
+  //                        q_seq_len,
+  //                        k_seq_len,
+  //                        alpha,
+  //                        static_cast<const half*>(v_lin_results_ptr),
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        static_cast<half*>(matmul2_results.data_ptr()),
+  //                        head_dim*attn_batches,
+  //                        head_dim,
+  //                        attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,
 ...
@@ -190,19 +263,43 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(outputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                static_cast<void*>(outputs.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //              CUBLAS_OP_T,
+  //              CUBLAS_OP_N,
+  //              embed_dim,
+  //              batches,
+  //              embed_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(output_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(matmul2_results.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(outputs.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_lin_results,
 ...
@@ -270,11 +367,12 @@ std::vector<torch::Tensor> bwd_cuda(
   char a_layout_t{'t'};
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
 ...
@@ -282,20 +380,45 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(output_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_lin_grads.data_ptr()),
-                CUDA_R_16F,
-                embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                rocblas_datatype_f16_r,
+                embed_dim,
+                static_cast<void*>(output_lin_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_N,
+  //              embed_dim,
+  //              batches,
+  //              embed_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(output_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(output_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(output_lin_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OOK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
 ...
@@ -303,19 +426,44 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(output_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_weight_grads.data_ptr()),
-                CUDA_R_16F,
-                embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                rocblas_datatype_f16_r,
+                embed_dim,
+                static_cast<void*>(output_weight_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_T,
+  //              embed_dim,
+  //              embed_dim,
+  //              batches,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(matmul2_results.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(output_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(output_weight_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // MatMul2 Dgrad1
+  // TODO (OK)
   gemm_switch_fp32accum( state,
                          a_layout_t,
                          b_layout_n,
 ...
@@ -333,9 +481,31 @@ std::vector<torch::Tensor> bwd_cuda(
                          static_cast<half*>(matmul2_grads.data_ptr()),
                          k_seq_len,
                          k_seq_len*q_seq_len,
+                         static_cast<half*>(matmul2_grads.data_ptr()),
+                         k_seq_len,
+                         k_seq_len*q_seq_len,
                          attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_t,
+  //                        b_layout_n,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        head_dim,
+  //                        alpha,
+  //                        static_cast<const half*>(v_lin_results_ptr),
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        static_cast<const half*>(output_lin_grads.data_ptr()),
+  //                        head_dim*attn_batches,
+  //                        head_dim,
+  //                        beta,
+  //                        static_cast<half*>(matmul2_grads.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        attn_batches);
   // Matmul2 Dgrad2
+  // TODO (OK)
   gemm_switch_fp32accum( state,
                          a_layout_n,
                          b_layout_t,
 ...
@@ -353,7 +523,28 @@ std::vector<torch::Tensor> bwd_cuda(
                          v_lin_grads_ptr,
                          lead_dim,
                          batch_stride,
+                         v_lin_grads_ptr,
+                         lead_dim,
+                         batch_stride,
                          attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_t,
+  //                        head_dim,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        alpha,
+  //                        static_cast<const half*>(output_lin_grads.data_ptr()),
+  //                        head_dim*attn_batches,
+  //                        head_dim,
+  //                        static_cast<const half*>(dropout_results.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        v_lin_grads_ptr,
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
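The masked-scale call itself is unchanged here (its remaining arguments fall outside the visible hunk). For context, a generic sketch of what this step does in dropout backward; this is not apex's apex_masked_scale_cuda, whose template arguments (at::Half, float, uint32_t) are visible above but whose body is not part of this diff, and the kernel name below is made up:

#include <cuda_fp16.h>
#include <cstdint>

// Illustrative only: multiply each incoming gradient by its saved dropout mask
// value and rescale by 1/(1 - dropout_prob), the usual dropout backward rule.
__global__ void masked_scale_sketch(const __half* grad_in, __half* grad_out,
                                    const uint8_t* mask, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float g = __half2float(grad_in[i]) * static_cast<float>(mask[i]) * scale;
    grad_out[i] = __float2half(g);
  }
}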
 ...
@@ -375,6 +566,7 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
   // Matmul1 Dgrad1
+  // TODO (OK)
   gemm_switch_fp32accum( state,
                          a_layout_n,
                          b_layout_n,
 ...
@@ -392,9 +584,31 @@ std::vector<torch::Tensor> bwd_cuda(
                          q_lin_grads_ptr,
                          lead_dim,
                          batch_stride,
+                         q_lin_grads_ptr,
+                         lead_dim,
+                         batch_stride,
                          attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_n,
+  //                        head_dim,
+  //                        q_seq_len,
+  //                        k_seq_len,
+  //                        scale,
+  //                        k_lin_results_ptr,
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        static_cast<half*>(matmul2_grads.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        q_lin_grads_ptr,
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        attn_batches);
   // Matmul1 Dgrad2
+  // TODO (OK)
   gemm_switch_fp32accum( state,
                          a_layout_n,
                          b_layout_t,
 ...
@@ -411,11 +625,33 @@ std::vector<torch::Tensor> bwd_cuda(
                          beta,
                          k_lin_grads_ptr,
                          lead_dim,
                          batch_stride,
+                         k_lin_grads_ptr,
+                         lead_dim,
+                         batch_stride,
                          attn_batches);
+  // gemm_switch_fp32accum( state,
+  //                        a_layout_n,
+  //                        b_layout_t,
+  //                        head_dim,
+  //                        k_seq_len,
+  //                        q_seq_len,
+  //                        scale,
+  //                        q_lin_results_ptr,
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        static_cast<half*>(matmul2_grads.data_ptr()),
+  //                        k_seq_len,
+  //                        k_seq_len*q_seq_len,
+  //                        beta,
+  //                        k_lin_grads_ptr,
+  //                        lead_dim,
+  //                        batch_stride,
+  //                        attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
 ...
@@ -423,20 +659,45 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_N,
+  //              embed_dim,
+  //              batches,
+  //              output_lin_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(input_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(q_lin_grads_ptr),
+  //              CUDA_R_16F,
+  //              output_lin_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(input_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  // TODO (OK)
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
 ...
@@ -444,18 +705,43 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(inputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_weight_grads.data_ptr()),
-                CUDA_R_16F,
-                embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                rocblas_datatype_f16_r,
+                embed_dim,
+                static_cast<void*>(input_weight_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
+  // THCublasCheck(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_T,
+  //              embed_dim,
+  //              output_lin_dim,
+  //              batches,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(inputs.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(q_lin_grads_ptr),
+  //              CUDA_R_16F,
+  //              output_lin_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(input_weight_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO (OK)
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_grads,
 ...