Unverified Commit 3fe10b55 authored by Burc Eryilmaz, committed by GitHub

Seryilmaz/fused dropout softmax (#985)

* fuse dropout into softmax in fprop for additive mask case
parent 6c186b3b
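For context, the computation being fused here is, in effect, a softmax over the attention scores plus an additive padding mask, followed by dropout. A rough PyTorch sketch of those semantics (function and variable names below are illustrative, not the extension's API):

import torch

def softmax_dropout_reference(bmm1_results, pad_mask, dropout_prob, is_training=True):
    # bmm1_results: attention scores (scaled Q @ K^T), shape [attn_batches, q_seq_len, k_seq_len]
    # pad_mask: additive mask, 0 where attention is allowed and a large negative value
    #           (e.g. -10000.0) at padded key positions, broadcastable to bmm1_results
    softmax_results = torch.softmax(bmm1_results + pad_mask, dim=-1)
    if not is_training:
        return softmax_results, None
    # Dropout keeps each element with probability (1 - dropout_prob) and rescales the
    # survivors by 1/(1 - dropout_prob); the fused kernel produces the dropped-and-scaled
    # probabilities and the byte keep-mask in a single pass over the scores.
    dropout_mask = torch.rand_like(softmax_results, dtype=torch.float32) >= dropout_prob
    dropout_results = softmax_results * dropout_mask.to(softmax_results.dtype) / (1.0 - dropout_prob)
    return dropout_results, dropout_mask.to(torch.uint8)

After this change the forward pass no longer saves the softmax output for backward; it keeps the pre-softmax scores (bmm1_results), the pad mask, and the dropout mask, and recomputes the softmax during the backward pass.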
@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
-  dispatch_masked_scale_softmax_backward<half, half, float,false>(
+  dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
       static_cast<half*>(output_grads.data_ptr()),
       static_cast<half*>(output_grads.data_ptr()),
       reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
       1.0/(1.0-dropout_prob),
       k_seq_len,
       k_seq_len,
-      attn_batches*q_seq_len);
+      attn_batches*q_seq_len, stream);
   //backward pass is completely in-place
   return output_grads;
 }
......
@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
   if (padding_mask == nullptr) {
-    dispatch_masked_scale_softmax_backward<half, half, float,false>(
+    dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
         static_cast<half*>(output_grads.data_ptr()),
         static_cast<half*>(output_grads.data_ptr()),
         reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
         1.0/(1.0-dropout_prob),
         k_seq_len,
         k_seq_len,
-        attn_batches*q_seq_len);
+        attn_batches*q_seq_len, stream);
   } else{
-    dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
+    dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
         static_cast<half*>(output_grads.data_ptr()),
         static_cast<half*>(output_grads.data_ptr()),
         reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
         k_seq_len,
         k_seq_len,
         attn_batches*q_seq_len,
-        heads);
+        heads, stream);
   }
   //backward pass is completely in-place
......
#pragma once
//Philox CUDA.
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
counter = make_uint4(0, 0, 0, 0);
counter.z = (unsigned int)(subsequence);
counter.w = (unsigned int)(subsequence >> 32);
STATE = 0;
incr_n(offset / 4);
}
__device__ inline uint4 operator()() {
if(STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
//7-round philox
for(int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
}
output = single_round(counter_, key_);
incr();
}
//return a float4 directly
//unsigned long ret;
//switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
//STATE = (STATE + 1) % 4;
return output;
}
private:
uint4 counter;
uint4 output;
uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ inline void incr() {
if (++counter.x)
return;
if (++counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a*b;
}
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);
}
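The header above defines the counter-based Philox generator used to draw the dropout randomness on the device. As a sanity check, a host-side Python transcription of the same round function is sketched below (a reference model only, using the constants from the header; it is not part of the extension):

# Reference model of the 7-round Philox4x32 generator defined in philox.h.
# All arithmetic is modulo 2**32, mirroring unsigned int overflow on the GPU.
MASK32 = 0xFFFFFFFF
KPHILOX10A, KPHILOX10B = 0x9E3779B9, 0xBB67AE85
KPHILOXSA, KPHILOXSB = 0xD2511F53, 0xCD9E8D57

def mulhilo32(a, b):
    p = a * b
    return p & MASK32, (p >> 32) & MASK32            # (low 32 bits, high 32 bits)

def single_round(ctr, key):
    lo0, hi0 = mulhilo32(KPHILOXSA, ctr[0])
    lo1, hi1 = mulhilo32(KPHILOXSB, ctr[2])
    return [(hi1 ^ ctr[1] ^ key[0]) & MASK32, lo1,
            (hi0 ^ ctr[3] ^ key[1]) & MASK32, lo0]

def philox_4x32(counter, key, rounds=7):
    ctr, k = list(counter), list(key)
    for _ in range(rounds - 1):                      # six key-bumped rounds ...
        ctr = single_round(ctr, k)
        k = [(k[0] + KPHILOX10A) & MASK32, (k[1] + KPHILOX10B) & MASK32]
    return single_round(ctr, k)                      # ... plus one final round

def uniform4(x4):
    # Same scaling as M_RAN_INVM32: map each uint32 into [0, 1).
    return [v * 2.3283064e-10 for v in x4]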
@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
                                torch::Tensor const& output_grads,
                                torch::Tensor const& matmul2_results,
                                torch::Tensor const& dropout_results,
-                               torch::Tensor const& softmax_results,
+                               // torch::Tensor const& softmax_results,
+                               torch::Tensor const& bmm1_results,
+                               torch::Tensor const& pad_mask,
                                torch::Tensor const& input_lin_results,
                                torch::Tensor const& inputs,
                                torch::Tensor const& input_weights,
@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
   AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
+  AT_ASSERTM(use_mask , "no mask is not supported");
   if (use_mask) {
     AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
                                torch::Tensor const& output_grads,
                                torch::Tensor const& matmul2_results,
                                torch::Tensor const& dropout_results,
-                               torch::Tensor const& softmax_results,
+                               torch::Tensor const& bmm1_results,
+                               torch::Tensor const& pad_mask,
                                torch::Tensor const& input_lin_results,
                                torch::Tensor const& inputs,
                                torch::Tensor const& input_weights,
@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
   AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
   AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
-  AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
   AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
                    output_grads,
                    matmul2_results,
                    dropout_results,
-                   softmax_results,
+                   bmm1_results,
+                   pad_mask,
                    input_lin_results,
                    inputs,
                    input_weights,
......
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
   auto mask_options = act_options.dtype(torch::kUInt8);
   torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
-  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
+  torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
   torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
   torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
   torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
   void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
   // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
-  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+  void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
+  void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
   char a_layout_t{'t'};
   char a_layout_n{'n'};
@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
                    lead_dim,
                    batch_stride,
                    beta_zero,
-                   static_cast<half*>(softmax_results_ptr),
+                   static_cast<half*>(bmm1_results_ptr),
                    k_seq_len,
                    k_seq_len*q_seq_len,
                    attn_batches);
   // Padded Softmax
   bool softmax_success = false;
-  if (pad_mask == nullptr) {
-    softmax_success = dispatch_softmax<half, half, float>(
-                   reinterpret_cast<half*>(softmax_results_ptr),
-                   reinterpret_cast<const half*>(softmax_results_ptr),
+  if (is_training) {
+    softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
+                   reinterpret_cast<half*>(dropout_results_ptr),
+                   (is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
+                   reinterpret_cast<const half*>(bmm1_results_ptr),
+                   pad_mask,
+                   attn_batches*q_seq_len*q_seq_len,
                    k_seq_len,
                    k_seq_len,
-                   attn_batches*q_seq_len);
+                   attn_batches*q_seq_len,
+                   attn_batches*q_seq_len/sequences,
+                   1.0f-dropout_prob,
+                   stream);
   } else {
     softmax_success = dispatch_additive_masked_softmax<half, half, float>(
-                   reinterpret_cast<half*>(softmax_results_ptr),
-                   reinterpret_cast<const half*>(softmax_results_ptr),
+                   reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function
+                   reinterpret_cast<const half*>(bmm1_results_ptr),
                    pad_mask,
                    k_seq_len,
                    k_seq_len,
@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
                    attn_batches*q_seq_len/sequences);
   }
-  if (is_training) {
-    //use at:: function so that C++ version generates the same random mask as python version
-    auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
-    dropout_results = std::get<0>(dropout_tuple);
-    dropout_mask = std::get<1>(dropout_tuple);
-  }
   // Matmul2
   gemm_switch_fp32accum( state,
                    a_layout_n,
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
                    static_cast<const half*>(v_lin_results_ptr),
                    lead_dim,
                    batch_stride,
-                   (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
+                   static_cast<const half*>(dropout_results.data_ptr()),
                    k_seq_len,
                    k_seq_len*q_seq_len,
                    beta_zero,
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
   return {
     input_lin_results,
-    softmax_results,
+    bmm1_results,
     dropout_results,
     dropout_mask,
     matmul2_results,
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
                                torch::Tensor const& output_grads,
                                torch::Tensor const& matmul2_results,
                                torch::Tensor const& dropout_results,
-                               torch::Tensor const& softmax_results,
+                               torch::Tensor const& bmm1_results,
+                               torch::Tensor const& pad_mask,
                                torch::Tensor const& input_lin_results,
                                torch::Tensor const& inputs,
                                torch::Tensor const& input_weights,
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
-  dispatch_masked_scale_softmax_backward<half, half, float,false>(
+  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
       static_cast<half*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
+      static_cast<half* const>(matmul2_grads.data_ptr()),
+      reinterpret_cast<half const*>(bmm1_results.data_ptr()),
+      reinterpret_cast<half const*>(pad_mask.data_ptr()),
       static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      1.0/(1.0-dropout_prob),
      k_seq_len,
      k_seq_len,
-     attn_batches*q_seq_len);
+     attn_batches*q_seq_len/sequences,
+     attn_batches*q_seq_len,
+     stream);
   // Matmul1 Dgrad1
   gemm_switch_fp32accum( state,
......
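Because the softmax output is no longer saved, the backward path above recomputes it from bmm1_results and pad_mask before forming the softmax gradient. At the PyTorch level, the math implemented by the recompute kernel is roughly the following (an illustrative sketch under those assumptions, not the kernel's signature):

import torch

def masked_scale_softmax_backward_recompute(output_grads, bmm1_results, pad_mask,
                                            dropout_mask, dropout_prob):
    # Recompute the softmax output from the saved pre-softmax scores and additive mask.
    y = torch.softmax(bmm1_results + pad_mask, dim=-1)
    # Undo dropout on the incoming gradient: zero dropped positions, rescale the rest.
    g = output_grads * dropout_mask.to(output_grads.dtype) / (1.0 - dropout_prob)
    # Standard softmax backward: dL/dx = y * (g - sum(g * y, dim=-1)).
    return y * (g - (g * y).sum(dim=-1, keepdim=True))

This pairs with the forward change above, which replaces the separate at::_fused_dropout call with the single dispatch_additive_masked_softmax_dropout kernel.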
@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
-  dispatch_masked_scale_softmax_backward<half, half, float,false>(
+  dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
      static_cast<half*>(matmul2_grads.data_ptr()),
      static_cast<half*>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
      1.0/(1.0-dropout_prob),
      k_seq_len,
      k_seq_len,
-     attn_batches*q_seq_len);
+     attn_batches*q_seq_len, stream);
   // Matmul1 Dgrad1
   gemm_switch_fp32accum( state,
......
@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
         dropout_prob_t = torch.tensor([dropout_prob])
         null_tensor = torch.tensor([])
         use_mask = (pad_mask is not None)
+        mask_additive_t= torch.tensor([mask_additive])
         if use_biases_t[0]:
             if not mask_additive:
@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                                               output_biases, \
                                               pad_mask if use_mask else null_tensor, \
                                               dropout_prob)
+                ctx.save_for_backward(use_biases_t, \
+                                      heads_t, \
+                                      matmul2_results, \
+                                      dropout_results, \
+                                      softmax_results, \
+                                      null_tensor, \
+                                      null_tensor, \
+                                      mask_additive_t, \
+                                      input_lin_results, \
+                                      inputs, \
+                                      input_weights, \
+                                      output_weights, \
+                                      dropout_mask, \
+                                      dropout_prob_t)
             else:
                 input_lin_results, \
-                softmax_results, \
+                bmm1_results, \
                 dropout_results, \
                 dropout_mask, \
                 matmul2_results, \
@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                                               output_biases, \
                                               pad_mask if use_mask else null_tensor, \
                                               dropout_prob)
+                ctx.save_for_backward(use_biases_t, \
+                                      heads_t, \
+                                      matmul2_results, \
+                                      dropout_results, \
+                                      null_tensor, \
+                                      bmm1_results, \
+                                      pad_mask, \
+                                      mask_additive_t, \
+                                      input_lin_results, \
+                                      inputs, \
+                                      input_weights, \
+                                      output_weights, \
+                                      dropout_mask, \
+                                      dropout_prob_t)
         else:
@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                                          output_weights, \
                                          pad_mask if use_mask else null_tensor, \
                                          dropout_prob)
             ctx.save_for_backward(use_biases_t, \
                                   heads_t, \
                                   matmul2_results, \
                                   dropout_results, \
                                   softmax_results, \
+                                  null_tensor, \
+                                  null_tensor, \
+                                  mask_additive_t, \
                                   input_lin_results, \
                                   inputs, \
                                   input_weights, \
                                   output_weights, \
                                   dropout_mask, \
                                   dropout_prob_t)
         return outputs.detach()
     @staticmethod
@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
         matmul2_results, \
         dropout_results, \
         softmax_results, \
+        bmm1_results, \
+        pad_mask, \
+        mask_additive_t, \
         input_lin_results, \
         inputs, \
         input_weights, \
@@ -101,6 +134,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
         dropout_prob_t = ctx.saved_tensors
         if use_biases_t[0]:
+            if not mask_additive_t[0]:
                 input_grads, \
                 input_weight_grads, \
                 output_weight_grads, \
@@ -119,6 +153,26 @@ class FastSelfAttnFunc(torch.autograd.Function) :
                     dropout_mask, \
                     dropout_prob_t[0])
+            else:
+                input_grads, \
+                input_weight_grads, \
+                output_weight_grads, \
+                input_bias_grads, \
+                output_bias_grads = \
+                    fast_self_multihead_attn_bias_additive_mask.backward( \
+                                          heads_t[0], \
+                                          output_grads, \
+                                          matmul2_results, \
+                                          dropout_results, \
+                                          bmm1_results, \
+                                          pad_mask, \
+                                          input_lin_results, \
+                                          inputs, \
+                                          input_weights, \
+                                          output_weights, \
+                                          dropout_mask, \
+                                          dropout_prob_t[0])
         else:
             input_bias_grads = None
             output_bias_grads = None
......
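The forward paths above all pack the same fourteen slots into ctx.save_for_backward, padding unused positions with null_tensor, so that backward can unpack one fixed tuple and branch on mask_additive_t afterwards. A toy illustration of that pattern (the class and names below are hypothetical, unrelated to the real extension):

import torch

class ToyFixedSlotsFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, use_extra):
        null_tensor = torch.tensor([])
        extra = x * 2.0 if use_extra else null_tensor     # only one path produces this
        # Always save the same slots in the same order; unused ones hold null_tensor.
        ctx.save_for_backward(x, extra, torch.tensor([use_extra]))
        return x.clamp(min=0.0)

    @staticmethod
    def backward(ctx, grad_output):
        x, extra, use_extra_t = ctx.saved_tensors         # fixed-order unpack
        grad_x = grad_output * (x > 0).to(grad_output.dtype)
        if use_extra_t[0]:
            grad_x = grad_x + 0.0 * extra                 # the extra slot is only touched here
        return grad_x, None                               # no gradient for the flag argument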
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnTest(unittest.TestCase):
    def setUp(self, seed=1234):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.seq_length = 80
        self.sequences = 10
        self.hidden_dim = 1024
        self.heads = 16
        self.dropout_prob = 0.0

        self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
                                           self.heads,
                                           dropout=self.dropout_prob,
                                           bias=True,
                                           include_norm_add=False,
                                           separate_qkv_params=True,
                                           mask_additive=True,
                                           impl='default')
        self.ref_layer.cuda().half()
        self.ref_layer.reset_parameters()

        self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

        # Reset seed so parameters are identical
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
                                           self.heads,
                                           dropout=self.dropout_prob,
                                           bias=True,
                                           include_norm_add=False,
                                           separate_qkv_params=True,
                                           mask_additive=True,
                                           impl='fast')
        self.tst_layer.cuda().half()
        self.tst_layer.reset_parameters()

        self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
                                      dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
    def test_self_multihead_attn_additive_mask(self):
        grads = torch.randn_like(self.tst_inputs)
        mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda()

        ref_outputs, _ = self.ref_layer.forward(self.ref_inputs,
                                                self.ref_inputs,
                                                self.ref_inputs,
                                                key_padding_mask=mask,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)

        tst_outputs, _ = self.tst_layer.forward(self.tst_inputs,
                                                self.tst_inputs,
                                                self.tst_inputs,
                                                key_padding_mask=mask,
                                                need_weights=False,
                                                attn_mask=None,
                                                is_training=True)

        # Backpropagate through the attention outputs so the fused backward kernels
        # are actually exercised before the gradient comparison below.
        ref_outputs.backward(grads)
        tst_outputs.backward(grads)

        self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
        self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
        self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))


if __name__ == '__main__':
    unittest.main()