Commit 038ed999 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Fix some compiling errors

parent bbf2c8d0
......@@ -261,7 +261,7 @@ cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U *buf = shared.getPointer();
U mu, sigma2;
......@@ -475,7 +475,7 @@ cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid,
const T *__restrict__ input, const int n1, const int n2,
const U *__restrict__ mean, const U *__restrict__ invvar,
U epsilon, const T *gamma, T *grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
......
......@@ -19,18 +19,13 @@ namespace multihead_attn {
namespace self_bias_additive_mask {
namespace rocblas_gemmex {
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
)
{
std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
int heads, torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask, float dropout_prob) {
const int embed_dim = inputs.size(2);
const int sequences = inputs.size(1);
const int q_seq_len = inputs.size(0);
......@@ -49,6 +44,32 @@ namespace rocblas_gemmex {
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
// by ATen library code)
auto act_options = inputs.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor input_lin_results =
torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
torch::Tensor bmm1_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results =
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask =
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
torch::Tensor matmul2_results =
torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
torch::Tensor outputs = torch::empty_like(inputs, act_options);
  // Input Linear Results Pointers to Q, K, and V of interleaved activations
void *q_lin_results_ptr = static_cast<void *>(input_lin_results.data_ptr());
void *k_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + head_dim);
void *v_lin_results_ptr = static_cast<void *>(
static_cast<half *>(input_lin_results.data_ptr()) + 2 * head_dim);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
......
......@@ -931,28 +931,6 @@ void HostApplyRMSNorm(
output, invvar, input, n1, n2, U(epsilon), gamma, warp_size);
}
// Host-side launcher for the RMSNorm forward kernel.
//
// Writes the normalized result to `output` and the per-row inverse variance
// to `invvar` for an (n1 x n2) contiguous input. Launches cuApplyRMSNorm on
// the current ATen CUDA stream with a fixed 32x4 thread block (one warp in x)
// and up to maxGridSize[1] blocks along y, one y-block stride per row.
template<typename T, typename U, typename V=T>
void HostApplyRMSNorm(
    V* output,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma)
{
  // All work is queued on the current ATen CUDA stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32, 4, 1);
  // Clamp the y-grid to the device limit; the kernel grid-strides over rows.
  const uint64_t max_grid_y =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min(static_cast<uint64_t>(n1), max_grid_y), 1);
  // Dynamic shared memory for the cross-warp reduction: threads.y partial
  // sums plus threads.y/2 scratch slots; none needed when threads.y == 1.
  const int shared_bytes =
      (threads.y > 1)
          ? static_cast<int>((threads.y + threads.y / 2) * sizeof(U))
          : 0;
  cuApplyRMSNorm<<<blocks, threads, shared_bytes, stream>>>(
      output, invvar, input, n1, n2, U(epsilon), gamma);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment