Unverified Commit 58580b50 authored by ver217, committed by GitHub

Revert "[NFC] Hotfix/format (#984)" (#986)

This reverts commit 0772828f.
parent 0772828f
@@ -15,8 +15,7 @@
#define BLOCK_SIZE 512
#define ILP 4

template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -29,25 +28,24 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
}

typedef enum {
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;

std::tuple<at::Tensor, at::Tensor>
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
                         std::vector<std::vector<at::Tensor>> tensor_lists,
                         at::optional<bool> per_tensor_python);

using MATH_T = float;

template <typename T> struct LAMBStage1Functor {
  __device__ __forceinline__ void
  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
             const float beta1, const float beta2, const float beta3,
             const float beta1_correction, const float beta2_correction,
             const float epsilon, adamMode_t mode, const float decay,
             const float *global_grad_norm, const float max_global_grad_norm) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -91,7 +89,8 @@ struct LAMBStage1Functor {
         i_start += blockDim.x) {
      // load
      load_store(l_g, g, 0, i_start);
      if (decay != 0)
        load_store(l_p, p, 0, i_start);
      load_store(l_m, m, 0, i_start);
      load_store(l_v, v, 0, i_start);
      // unpack
@@ -205,12 +204,12 @@ struct LAMBStage1Functor {

// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template <typename T> struct LAMBStage2Functor {
  __device__ __forceinline__ void
  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
             const float *per_tensor_param_norm,
             const float *per_tensor_update_norm, const float learning_rate,
             const float decay, bool use_nvlamb) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -311,7 +310,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1)
    beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                 tensor_lists.begin() + 1);
@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
      tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
                            beta3, // 1-beta1 or 1 depends on averaging mode
                            bias_correction1, bias_correction2, epsilon,
                            (adamMode_t)mode, weight_decay,
                            global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
......
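For orientation, a rough Python sketch of the per-element moment update a stage-1 LAMB functor with this signature typically performs (illustrative only, not part of this commit; it assumes the decoupled weight decay mode, and the trust-ratio scaling happens in stage 2):

import math

def lamb_stage1(g, p, m, v, beta1, beta2, beta3,
                beta1_correction, beta2_correction, epsilon, decay):
    # beta3 is 1 - beta1 or 1, depending on the grad averaging mode above
    m = beta1 * m + beta3 * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / beta1_correction   # bias corrections are computed on the host
    v_hat = v / beta2_correction
    update = m_hat / (math.sqrt(v_hat) + epsilon)
    if decay != 0.0:               # decoupled weight decay (MOMENT_MODE_1) adds decay * p
        update += decay * p
    return m, v, update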
@@ -15,8 +15,7 @@
#define BLOCK_SIZE 512
#define ILP 4

template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -28,8 +27,7 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}

template <typename in_t, typename out_t> struct ScaleFunctor {
  __device__ __forceinline__ void operator()(int chunk_size,
                                             volatile int *noop_gmem,
                                             TensorListMetadata<2> &tl,
@@ -78,7 +76,8 @@ struct ScaleFunctor {
      for (int ii = 0; ii < ILP; ii++) {
        r_in[ii] = 0;
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size)
          r_in[ii] = in[i];
      }
      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point
@@ -94,13 +93,14 @@ struct ScaleFunctor {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size)
          out[i] = r_out[ii];
      }
    }
  }
  if (!finite)
    *noop_gmem =
        1; // Blindly fire off a write. These will race but that's ok.
}
};
......
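The guarded loads and stores above follow the usual chunked ILP indexing: each of blockDim.x threads touches ILP strided elements per iteration. A small Python sketch (illustrative, with made-up sizes, not part of this commit) showing that the indices cover a chunk exactly once:

BLOCK_SIZE, ILP, chunk = 512, 4, 2048
covered = set()
for i_start in range(0, chunk, BLOCK_SIZE * ILP):
    for tid in range(BLOCK_SIZE):           # stands in for threadIdx.x
        for ii in range(ILP):
            i = i_start + tid + ii * BLOCK_SIZE
            if i < chunk:                    # mirrors the i < n && i < chunk_size guard
                covered.add(i)
assert covered == set(range(chunk))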
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4
@@ -29,53 +28,69 @@
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int *noop_gmem,
        TensorListMetadata<N> &tl,
        float wd,
        float momentum,
        float dampening,
        float lr,
        bool nesterov,
        bool first_run,
        bool wd_after_momentum,
        float scale)
    {
        // Early exit if we don't need to do anything
        if (*noop_gmem)
            return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
        grad_in += chunk_idx * chunk_size;

        T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
        weight_in += chunk_idx * chunk_size;

        T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
        mom_in += chunk_idx * chunk_size;

        at::Half *model_weights_out = nullptr;
        if (N == 4)
        {
            model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
            model_weights_out += chunk_idx * chunk_size;
        }

        n -= chunk_idx * chunk_size;

        // Non-divergent exit condition for the __syncthreads
        float incoming_grads[ILP];
        float incoming_weights[ILP];
        float incoming_moms[ILP];
        for (int i_start = 0;
             i_start < n && i_start < chunk_size;
             i_start += blockDim.x * ILP)
        {
#pragma unroll
            for (int ii = 0; ii < ILP; ii++)
            {
                incoming_grads[ii] = 0;
                incoming_weights[ii] = 0;
                incoming_moms[ii] = 0;
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size)
                {
                    incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
                    incoming_weights[ii] = static_cast<float>(weight_in[i]);
                    incoming_moms[ii] = static_cast<float>(mom_in[i]);
                }
            }

            // note for clarification to future michael:
            // From a pure memory dependency perspective, there's likely no point unrolling
@@ -83,128 +98,185 @@ struct SGDFunctor {
            // Put another way, the STGs are dependent on the LDGs, but not on each other.
            // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
            for (int ii = 0; ii < ILP; ii++)
            {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size)
                {
                    // apply weight decay before momentum if necessary
                    if (wd != 0.f && !wd_after_momentum)
                        incoming_grads[ii] += wd * incoming_weights[ii];

                    if (momentum != 0.f)
                    {
                        if (!first_run)
                            incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
                        else // initialize momentums to current incoming grads
                            incoming_moms[ii] = incoming_grads[ii];

                        if (nesterov)
                            incoming_grads[ii] += momentum * incoming_moms[ii];
                        else
                            incoming_grads[ii] = incoming_moms[ii];
                    }

                    // Apply WD after momentum if desired
                    if (wd != 0.f && wd_after_momentum)
                        incoming_grads[ii] += wd * incoming_weights[ii];

                    // adjust the weight and write out
                    weight_in[i] += (-lr * incoming_grads[ii]);

                    // if necessary, write out an fp16 copy of the weights
                    if (N == 4)
                        model_weights_out[i] = static_cast<at::Half>(weight_in[i]);

                    // also write out the new momentum
                    if (momentum != 0.f)
                        mom_in[i] = incoming_moms[ii];
                }
            }
        }
    }
};
void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
{
    auto num_tensors = tensor_lists.size();
    auto grad_type = tensor_lists[0][0].scalar_type();
    auto weight_type = tensor_lists[1][0].scalar_type();

    if (num_tensors == 4)
        for (int i = 0; i < tensor_lists[3].size(); i++)
            TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                        "Additional output tensors should always be fp16.");

    TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

    // We have 3 possibilities to handle here, in terms of
    // grad_type, param_type, momentum_type, requires_fp16_copy
    // 1. fp16, fp16, fp16, No
    // 2. fp32, fp32, fp32, No
    // 3. fp16, fp32, fp32, Yes
    // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
    // It's easier to hardcode these possibilities than to use
    // switches etc. to handle the cross-product of cases where
    // we don't want the majority of them.

    // Case 1. fp16, fp16, fp16, No
    if (grad_type == at::ScalarType::Half &&
        weight_type == at::ScalarType::Half &&
        num_tensors == 3)
    {
        multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<3, at::Half, at::Half>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    // Case 2. fp16, fp32, fp32, No
    // else if (grad_type == at::ScalarType::Half &&
    //          weight_type == at::ScalarType::Float &&
    //          num_tensors == 3) {
    //   multi_tensor_apply<3>(
    //     BLOCK_SIZE,
    //     chunk_size,
    //     noop_flag,
    //     tensor_lists,
    //     SGDFunctor<3, at::Half, float>(),
    //     wd,
    //     momentum,
    //     dampening,
    //     lr,
    //     nesterov,
    //     first_run,
    //     wd_after_momentum);
    // }
    // Case 2. fp32, fp32, fp32, No
    else if (grad_type == at::ScalarType::Float &&
             weight_type == at::ScalarType::Float &&
             num_tensors == 3)
    {
        multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<3, float, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    // Case 3. fp16, fp32, fp32, Yes
    else if (grad_type == at::ScalarType::Half &&
             weight_type == at::ScalarType::Float &&
             num_tensors == 4)
    {
        multi_tensor_apply<4>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<4, at::Half, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    // Case 4. fp32, fp32, fp32, Yes
    else if (grad_type == at::ScalarType::Float &&
             weight_type == at::ScalarType::Float &&
             num_tensors == 4)
    {
        multi_tensor_apply<4>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<4, float, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run,
            wd_after_momentum,
            scale);
    }
    else
    {
        AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
                 "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
    }

    AT_CUDA_CHECK(cudaGetLastError());
}
\ No newline at end of file
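For reference, a scalar Python sketch of the per-element update that the SGDFunctor above applies (illustrative only, not part of this commit):

def sgd_step(grad, weight, mom, wd, momentum, dampening, lr,
             nesterov, first_run, wd_after_momentum, scale):
    grad = grad * scale
    if wd != 0.0 and not wd_after_momentum:    # weight decay before momentum
        grad += wd * weight
    if momentum != 0.0:
        mom = grad if first_run else mom * momentum + (1.0 - dampening) * grad
        grad = grad + momentum * mom if nesterov else mom
    if wd != 0.0 and wd_after_momentum:        # or after momentum, if requested
        grad += wd * weight
    weight += -lr * grad                       # adjust the weight and write out
    return weight, mom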
@@ -19,25 +19,21 @@
template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
                     int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);

  virtual ~MultiHeadAttention();

  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

  void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
                const T *input_mask_ptr, T *grad_input_ptr);

  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);

  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
                     const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);

  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;

@@ -87,17 +83,14 @@ class MultiHeadAttention {
    }

    _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
    _soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

    // buffer size needed by attn bw
    size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
                       std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
                                _max_batch_tokens * _heads / pg_size * _max_seq_len);

    if (!_shared_mem_ptr) {
      cuda_free(_shared_mem_ptr);
......
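To make the smem_size formula above concrete, a hedged Python example with made-up sizes (none of these numbers come from this diff):

max_batch_tokens, hidden_size, heads, max_seq_len, pg_size = 4096, 1024, 16, 256, 1
smem_size = 4 * max_batch_tokens * hidden_size // pg_size + \
    max(3 * max_batch_tokens * hidden_size // pg_size,
        max_batch_tokens * heads // pg_size * max_seq_len)
# 16777216 + max(12582912, 16777216) = 33554432 elements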
@@ -2,13 +2,12 @@
 * with minor changes. */
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"
@@ -16,15 +15,17 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);

@@ -37,10 +38,10 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());

@@ -48,23 +49,31 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          pad_batches);
  );

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {

  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);

@@ -72,18 +81,24 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
@@ -3,52 +3,57 @@
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
@@ -2,13 +2,12 @@
 * with minor changes. */
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
@@ -16,15 +15,18 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
@@ -34,42 +36,50 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {

  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
@@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

@@ -72,7 +72,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
    def forward(self, input):
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                  self.normalized_shape, self.eps)

    def __repr__(self):
        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
@@ -28,7 +28,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        scale_t = torch.tensor([scale])
        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

@@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None

@@ -77,7 +81,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = colossal_scaled_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

@@ -108,8 +114,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@@ -117,7 +124,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]

@@ -131,13 +140,14 @@ class FusedScaleMaskSoftmax(nn.Module):
    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
......
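A tiny Python check of the fusion conditions listed above (an illustrative sketch; the sample values are made up, not taken from this diff):

def fusion_ok(b, np, sq, sk, input_in_float16=True, mask_given=True, fusion_enabled=True):
    # mirrors the conditions above: fp16/bf16 input, a mask, 16 < sk <= 2048,
    # sq divisible by 4, and b * np divisible by 4
    return (fusion_enabled and input_in_float16 and mask_given
            and 16 < sk <= 2048 and sq % 4 == 0 and (b * np) % 4 == 0)

print(fusion_ok(b=8, np=16, sq=128, sk=128))    # True
print(fusion_ok(b=8, np=16, sq=128, sk=4096))   # False: sk exceeds 2048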
import torch

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678

@@ -8,12 +9,10 @@ import torch
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# gradient of tanh approximation of gelu
# gradient of actual gelu is:

@@ -24,11 +23,9 @@ def bias_gelu_back(g, bias, y):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff*g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):

@@ -41,5 +38,4 @@ class GeLUFunction(torch.autograd.Function):
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
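The constants in the comments above come from the tanh approximation of GELU. A short hedged check (illustrative, not part of this diff) comparing it against the erf form:

import torch

x = torch.linspace(-4, 4, steps=9)
gelu_erf = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))                                 # exact form, 1/sqrt(2)
gelu_tanh = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))      # sqrt(2/pi) form
print((gelu_erf - gelu_tanh).abs().max())   # stays well below 1e-2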
@@ -182,7 +182,7 @@ class Linear2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
        output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                    ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,

@@ -337,16 +337,16 @@ class LayerNorm2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

@@ -569,7 +569,7 @@ class PatchEmbedding2D(ParallelLayer):
        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
        pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)

@@ -1012,7 +1012,7 @@ class Classifier2D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes, )

        return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,

@@ -1186,7 +1186,7 @@ class VocabParallelClassifier2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.output_size_per_partition, )
        output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
......
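The LayerNorm2D forward above rebuilds the variance from the identity Var[x] = E[x^2] - E[x]^2, all-reducing the partial sums across the row group. A single-device Python sketch of the same identity (illustrative only, not part of this commit):

import torch

x = torch.randn(2, 8, 1024)                                # [b, s, h], h unpartitioned here
E_x = x.sum(dim=-1, keepdim=True) / x.shape[-1]            # the 2D layer all-reduces this sum
E_x2 = (x * x).sum(dim=-1, keepdim=True) / x.shape[-1]
Var_x = E_x2 - E_x * E_x
inv_std = 1.0 / torch.sqrt(Var_x + 1e-5)
assert torch.allclose(Var_x, x.var(dim=-1, unbiased=False, keepdim=True), atol=1e-4)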
@@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )

        output = Matmul_AB_2p5D.apply(
            x,

@@ -254,7 +254,7 @@ class LayerNorm2p5D(ParallelLayer):
        self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()

        # partitioning dimension
        self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)    # *

        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}

@@ -357,16 +357,16 @@ class LayerNorm2p5D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

@@ -589,7 +589,7 @@ class PatchEmbedding2p5D(ParallelLayer):
        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
        pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)

@@ -1038,7 +1038,7 @@ class Classifier2p5D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes, )

        return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
                               self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,

@@ -1172,7 +1172,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
        output = Matmul_ABT_2p5D.apply(
            x,
......
@@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
        self.weight = Parameter(
            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
        if bias:
            self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
                                              device=get_current_device(), dtype=dtype))
        else:
            self.bias = None
        self.variance_epsilon = eps

@@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = self.cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
......
@@ -13,8 +13,7 @@ from torch import Tensor, nn
class CheckpointModule(nn.Module):
    def __init__(self, checkpoint: bool = True, offload : bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint

@@ -79,7 +78,6 @@ def get_tensor_parallel_mode():
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
......