Unverified Commit 0772828f authored by binmakeswell, committed by GitHub

[NFC] Hotfix/format (#984)



* [NFC] Polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code style. (#937)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h code style (#939)

* [NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#936)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h code style (#938)

* [NFC] polish moe_cuda_kernel.cu code style (#940)
Co-authored-by: Xiao Ye <xiaoye2@illinois.edu>

* [NFC] polish pre-commit run --files colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu code style (#943)

* [NFC] polish colossalai/kernel/cuda_native/csrc/moe_cuda.cpp code style (#942)

* [NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.h code style (#945)

* [NFC] polish colossalai/kernel/jit/bias_gelu.py code style (#946)
Co-authored-by: jnbai <897086360@qq.com>

* [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu code style (#949)
Co-authored-by: Jiatong <jiatong.han@u.nus.edu>

* [NFC] polish colossalai/builder/pipeline.py code style (#951)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp code style (#952)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu code style (#953)
Co-authored-by: 何晓昕 <cautious@hexiaoxins-MacBook-Pro.local>

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu code style (#954)

* [NFC] polish colossalai/kernel/cuda_native/scaled_softmax.py  code style (#955)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/context.h code style (#956)
Co-authored-by: RichardoLuo <14049555596@qq.com>

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h code style (#957)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu code style (#958)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h code style (#962)

* [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp code style (#959)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu code style (#963)
Co-authored-by: “Arsmart123 <202476410arsmart@gmail.com>

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h code style (#964)

* [NFC] polish __init__.py code style (#965)

* [NFC] polish colossalai/nn/layer/parallel_3d/layers.py code style (#966)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h code style (#968)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h code style (#970)

* [NFC] polish colossalai/nn/layer/parallel_2p5d/layers.py code style (#972)

* [NFC] polish colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp code style (#973)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu code style (#974)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu code style (#977)

* [NFC] polish colossalai/nn/layer/parallel_2d/layers.py code style (#976)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu code style (#978)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu code style (#979)

* [NFC] polish colossalai/kernel/cuda_native/layer_norm.py code style (#980)

* [NFC] polish colossalai/nn/layer/utils/common.py code style (#983)
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Co-authored-by: yuxuan-lou <83441848+yuxuan-lou@users.noreply.github.com>
Co-authored-by: Geng Zhang <34452939+zxgx@users.noreply.github.com>
Co-authored-by: Maruyama_Aya <38985202+MaruyamaAya@users.noreply.github.com>
Co-authored-by: XYE <92607131+Itok2000u@users.noreply.github.com>
Co-authored-by: Xiao Ye <xiaoye2@illinois.edu>
Co-authored-by: HaoyuQin <79465534+coder-chin@users.noreply.github.com>
Co-authored-by: wky <64853922+wangkuangyi@users.noreply.github.com>
Co-authored-by: bajiaoyu517 <59548007+bajiaoyu517@users.noreply.github.com>
Co-authored-by: luoling-LC <105470086+luoling-LC@users.noreply.github.com>
Co-authored-by: jnbai <897086360@qq.com>
Co-authored-by: JT.Han <59948448+JThh@users.noreply.github.com>
Co-authored-by: Jiatong <jiatong.han@u.nus.edu>
Co-authored-by: xyupeng <99191637+xyupeng@users.noreply.github.com>
Co-authored-by: Sze-qq <68757353+Sze-qq@users.noreply.github.com>
Co-authored-by: Cautiousss <48676630+Cautiousss@users.noreply.github.com>
Co-authored-by: 何晓昕 <cautious@hexiaoxins-MacBook-Pro.local>
Co-authored-by: Luxios22 <67457897+Luxios22@users.noreply.github.com>
Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com>
Co-authored-by: RichardoLuo <50363844+RichardoLuo@users.noreply.github.com>
Co-authored-by: RichardoLuo <14049555596@qq.com>
Co-authored-by: doubleHU <98150031+huxin711@users.noreply.github.com>
Co-authored-by: runluo <68489000+run-qiao@users.noreply.github.com>
Co-authored-by: MaxT <854721132@qq.com>
Co-authored-by: superhao1995 <804673818@qq.com>
Co-authored-by: ziyu huang <huang0ziyu@gmail.com>
Co-authored-by: “Arsmart123 <202476410arsmart@gmail.com>
Co-authored-by: Yuer867 <62204893+Yuer867@users.noreply.github.com>
Co-authored-by: lucasliunju <lucasliunju@gmail.com>
Co-authored-by: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Co-authored-by: ExtremeViscent <zhangyiqi55732@sina.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zirui Zhu <zhuzr21@gmail.com>
Co-authored-by: Ofey Chan <ofey206@gmail.com>
Co-authored-by: DouJS <dujiangsu@163.com>
Co-authored-by: Jie Zhu <chore.08-protist@icloud.com>
Co-authored-by: shenggan <csg19971016@gmail.com>
Co-authored-by: Kai Wang (Victor Kai) <37533040+kaiwang960112@users.noreply.github.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: Ziheng Qin <37519855+henryqin1997@users.noreply.github.com>
parent 5898ccf3
colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu
@@ -15,7 +15,8 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -28,24 +29,25 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
}
typedef enum {
MOMENT_MODE_0 = 0, // L2 regularization mode
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
MOMENT_MODE_0 = 0, // L2 regularization mode
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor>
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
using MATH_T = float;
template <typename T> struct LAMBStage1Functor {
__device__ __forceinline__ void
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta3,
const float beta1_correction, const float beta2_correction,
const float epsilon, adamMode_t mode, const float decay,
const float *global_grad_norm, const float max_global_grad_norm) {
template <typename T>
struct LAMBStage1Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta3,
const float beta1_correction, const float beta2_correction,
const float epsilon, adamMode_t mode, const float decay,
const float *global_grad_norm, const float max_global_grad_norm) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
@@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
i_start += blockDim.x) {
// load
load_store(l_g, g, 0, i_start);
if (decay != 0)
load_store(l_p, p, 0, i_start);
if (decay != 0) load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
@@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template <typename T> struct LAMBStage2Functor {
__device__ __forceinline__ void
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
const float *per_tensor_param_norm,
const float *per_tensor_update_norm, const float learning_rate,
const float decay, bool use_nvlamb) {
template <typename T>
struct LAMBStage2Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
const float learning_rate, const float decay, bool use_nvlamb) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
@@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1)
beta3 = 1 - beta1;
if (grad_averaging == 1) beta3 = 1 - beta1;
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
tensor_lists.begin() + 1);
@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
beta3, // 1-beta1 or 1 depends on averaging mode
bias_correction1, bias_correction2, epsilon,
(adamMode_t)mode, weight_decay,
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
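
For orientation, a minimal scalar sketch of the two-stage update these functors implement: stage 1 produces an Adam-style "update" value (with the L2 vs. decoupled weight-decay modes from adamMode_t), and stage 2 applies the per-tensor trust-ratio step. Gradient-norm clipping, bias corrections, and the multi-tensor plumbing are omitted; all names and values below are illustrative, not the library API.

#include <cmath>
#include <cstdio>

enum adamMode_t { MOMENT_MODE_0, MOMENT_MODE_1 };  // L2 reg vs. decoupled decay

int main() {
  // Hypothetical single-parameter optimizer state.
  float p = 0.5f, g = 0.1f, m = 0.0f, v = 0.0f;
  float beta1 = 0.9f, beta2 = 0.999f, eps = 1e-6f, decay = 0.01f, lr = 1e-3f;
  adamMode_t mode = MOMENT_MODE_1;

  // Stage 1: moment updates producing the 'update' value.
  if (mode == MOMENT_MODE_0) g += decay * p;       // L2 regularization mode
  m = beta1 * m + (1.0f - beta1) * g;
  v = beta2 * v + (1.0f - beta2) * g * g;
  float update = m / (std::sqrt(v) + eps);         // bias corrections omitted
  if (mode == MOMENT_MODE_1) update += decay * p;  // decoupled weight decay

  // Stage 2: trust-ratio scaling with per-tensor norms, then the step.
  float param_norm = std::fabs(p), update_norm = std::fabs(update);
  float ratio = (param_norm > 0.0f && update_norm > 0.0f)
                    ? param_norm / update_norm : 1.0f;
  p -= lr * ratio * update;
  std::printf("p = %f\n", p);
}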
colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu
@@ -15,7 +15,8 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -27,7 +28,8 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename in_t, typename out_t> struct ScaleFunctor {
template <typename in_t, typename out_t>
struct ScaleFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int *noop_gmem,
TensorListMetadata<2> &tl,
@@ -76,8 +78,7 @@ template <typename in_t, typename out_t> struct ScaleFunctor {
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size)
r_in[ii] = in[i];
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point
@@ -93,14 +94,13 @@ template <typename in_t, typename out_t> struct ScaleFunctor {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size)
out[i] = r_out[ii];
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite)
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
1; // Blindly fire off a write. These will race but that's ok.
}
};
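
The guarded loads above follow the same pattern across these kernels: chunked iteration, ILP-wide unrolling, and an `i < n && i < chunk_size` check on the partial tail. A host-side simplification (the real functor operates on device pointers taken from TensorListMetadata; the helper names here are illustrative):

#include <cmath>
#include <cstdint>
#include <cstdio>

#define ILP 4

template <typename T>
bool is_aligned(T *p) {  // same alignment test the kernels use
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

// Scale one chunk, tracking finiteness the way ScaleFunctor does.
void scale_chunk(const float *in, float *out, int n, int chunk_size,
                 float scale, int *noop_flag) {
  bool finite = true;
  for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += ILP) {
    for (int ii = 0; ii < ILP; ii++) {
      int i = i_start + ii;
      if (i < n && i < chunk_size) {  // guard the partial tail
        float r = in[i] * scale;
        finite = finite && std::isfinite(r);
        out[i] = r;
      }
    }
  }
  if (!finite) *noop_flag = 1;  // blind write; racing writes are tolerated
}

int main() {
  float in[6] = {1, 2, 3, 4, 5, 6}, out[6];
  int noop = 0;
  scale_chunk(in, out, 6, 8, 0.5f, &noop);
  std::printf("out[5] = %f, noop = %d, aligned = %d\n", out[5], noop,
              (int)is_aligned(out));
}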
colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"
#include <assert.h>
#include <cuda_runtime.h>
#include "compat.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
@@ -28,69 +29,53 @@
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int *noop_gmem,
TensorListMetadata<N> &tl,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
// Early exit if we don't need to do anything
if (*noop_gmem)
return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
grad_in += chunk_idx * chunk_size;
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
weight_in += chunk_idx * chunk_size;
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
mom_in += chunk_idx * chunk_size;
at::Half *model_weights_out = nullptr;
if (N == 4)
{
model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx * chunk_size;
}
struct SGDFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
float wd, float momentum, float dampening, float lr, bool nesterov,
bool first_run, bool wd_after_momentum, float scale) {
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
grad_in += chunk_idx * chunk_size;
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
weight_in += chunk_idx * chunk_size;
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
mom_in += chunk_idx * chunk_size;
at::Half *model_weights_out = nullptr;
if (N == 4) {
model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx * chunk_size;
}
n -= chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for (int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP)
{
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++)
{
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size)
{
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
for (int ii = 0; ii < ILP; ii++) {
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
@@ -98,185 +83,128 @@ struct SGDFunctor
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size)
{
// apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if (momentum != 0.f)
{
if (!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if (nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if (N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum
if (momentum != 0.f)
mom_in[i] = incoming_moms[ii];
}
}
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
// apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if (momentum != 0.f) {
if (!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum +
(1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if (nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if (N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
}
}
}
}
};
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if (num_tensors == 4)
for (int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, at::Half, at::Half>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, at::Half, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if (num_tensors == 4)
for (int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float && num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float && num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
} else {
AT_ERROR(
"multi_tensor_sgd only supports some combinations of gradient & weight "
"types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type,
", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
\ No newline at end of file
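
The reformatted functor is easier to audit against a scalar restatement of its update rule: weight decay before or after momentum, dampening, optional Nesterov, and first-run momentum initialization. A sketch for a single scalar parameter (not the multi-tensor entry point):

#include <cstdio>

void sgd_step(float &w, float &mom, float grad, float wd, float momentum,
              float dampening, float lr, bool nesterov, bool first_run,
              bool wd_after_momentum, float scale) {
  float g = grad * scale;
  if (wd != 0.f && !wd_after_momentum) g += wd * w;  // decay before momentum
  if (momentum != 0.f) {
    if (!first_run)
      mom = mom * momentum + (1.f - dampening) * g;
    else  // initialize momentum to the incoming gradient
      mom = g;
    g = nesterov ? g + momentum * mom : mom;
  }
  if (wd != 0.f && wd_after_momentum) g += wd * w;   // decay after momentum
  w += -lr * g;
}

int main() {
  float w = 1.0f, mom = 0.0f;
  sgd_step(w, mom, 0.2f, 1e-4f, 0.9f, 0.0f, 0.01f, false, true, false, 1.0f);
  std::printf("w = %f, mom = %f\n", w, mom);
}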
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
@@ -10,8 +10,9 @@
#include "kernels.h"
template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len,
int hidden_size, int num_heads,
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_size,
int num_heads,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm)
@@ -22,18 +23,22 @@ MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, in
_heads(num_heads),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_qkv_linear(typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
_attn_out_linear(typename FeedForward<T>::Config(hidden_size, hidden_size)),
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false), _max_batch_tokens),
_qkv_linear(
typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
_attn_out_linear(
typename FeedForward<T>::Config(hidden_size, hidden_size)),
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
_max_batch_tokens),
_softmax(typename Softmax<T>::Config(num_heads)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
_max_batch_tokens * _heads * _max_seq_len),
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
_max_batch_tokens * _hidden_size),
_attn_scores(typename StridedBatchGemm<T>::Config((T(1.0) / T(sqrt(_hidden_size / _heads))),
T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)),
_attn_context(
typename StridedBatchGemm<T>::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
_attn_scores(typename StridedBatchGemm<T>::Config(
(T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
CUBLAS_OP_N)),
_attn_context(typename StridedBatchGemm<T>::Config(
T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
assert(_hidden_size % _heads == 0);
}
@@ -43,43 +48,52 @@ MultiHeadAttention<T>::~MultiHeadAttention() {
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
const T *input_mask_ptr,
T *output_ptr, T *buffer) {
T *q_tf_ptr = _qkv_ptr;
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
if (_pre_or_postLayerNorm) {
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
_stream);
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
_batch_tokens, _stream);
}
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
const T *gemmQKV_inp_ptr =
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle);
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
_cublasHandle);
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3,
_heads / pg_size, _hidden_size / _heads, _stream);
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
_batch_size, _seq_len, 3, _heads / pg_size,
_hidden_size / _heads, _stream);
// attention scores, q*k
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
_cublasHandle);
// Softmax + Mask
_softmax.reset_size(_heads / pg_size);
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true);
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
_seq_len, _stream, true);
// attn prob dropout.
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len,
_stream);
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
_batch_heads * _seq_len * _seq_len, _stream);
// attention context, score * v
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle);
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
_cublasHandle);
// [b, nh, s, ad] -> [b, s, nh, ad]
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size,
_heads / pg_size, 1, _stream);
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
_hidden_size / pg_size, _heads / pg_size, 1,
_stream);
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle);
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
output_ptr, _cublasHandle);
// allreduce
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
@@ -88,24 +102,27 @@ void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mas
if (typeid(T) != typeid(float)) {
data_type = torch::kHalf;
}
auto output_tensor =
torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
auto output_tensor = torch::from_blob(
output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
work->wait();
}
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr,
_batch_tokens, _hidden_size, _stream);
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
_attn_ob_ptr, _batch_tokens, _hidden_size,
_stream);
if (!_pre_or_postLayerNorm) {
// in-place ln since ln-input will not be used in post-ln mode
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream);
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
_batch_tokens, _stream);
}
}
template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr) {
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
T *out_ptr) {
_stream = Context::Instance().get_stream();
_cublasHandle = Context::Instance().get_cublashandle();
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
@@ -114,8 +131,11 @@ void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
const T *grad_output_ptr, T *grad_input_ptr, T *buffer) {
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
const T *input_mask_ptr,
const T *output_ptr,
const T *grad_output_ptr,
T *grad_input_ptr, T *buffer) {
cudaStream_t streams[2] = {_stream, _stream};
const T *q_tf_ptr = _qkv_ptr;
@@ -137,45 +157,57 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
// batch_size * head_num * seq_len * seq_len);
if (_pre_or_postLayerNorm) {
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr,
_batch_tokens, _hidden_size, _stream);
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
grad_output_ptr, _batch_tokens,
_hidden_size, _stream);
} else {
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr,
nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr,
_batch_tokens, _hidden_size, _stream);
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
_attn_nb_ptr, _batch_tokens, streams);
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
grad_residual_ptr, _batch_tokens,
_hidden_size, _stream);
}
// bw of output project
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr,
_grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream,
grad_input_buf_ptr, nullptr, false);
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len,
_hidden_size / pg_size, _heads / pg_size, _stream);
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
_attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
false);
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
_seq_len, _hidden_size / pg_size, _heads / pg_size,
_stream);
// bw of score * v
_attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
_attn_context.Backward(
_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
_attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream);
_attn_prob_dropout.d_dropout(grad_softmax_ptr,
_batch_heads * _seq_len * _seq_len, _stream);
_softmax.reset_size(_heads / pg_size);
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream);
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
_seq_len, _stream);
// bw of q * k
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle,
grad_qkv_5d_ptr + _batch_dim / pg_size, grad_qkv_5d_ptr);
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
_cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
grad_qkv_5d_ptr);
// [3, b, nh, s, ad] -> [b, s, 3, h]
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len,
_hidden_size / pg_size, _heads / pg_size, 3, _stream);
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
_seq_len, _hidden_size / pg_size, _heads / pg_size,
3, _stream);
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
const T *gemmQKV_inp_ptr =
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr,
_grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream,
grad_input_buf_ptr, nullptr, true);
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
_attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
true);
// allreduce
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
@@ -185,7 +217,8 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
data_type = torch::kHalf;
}
auto grad_input_tensor =
torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::from_blob(grad_input_buf_ptr,
{int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
@@ -193,19 +226,21 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
}
if (_pre_or_postLayerNorm) {
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr,
grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
streams);
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
_attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
} else {
// FIXME later
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size,
_seq_len, _hidden_size, _stream);
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
_batch_size, _seq_len, _hidden_size, _stream);
}
}
template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
const T *input_mask_ptr, T *grad_input_ptr) {
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
const T *input_ptr, const T *output_ptr,
const T *input_mask_ptr,
T *grad_input_ptr) {
_stream = Context::Instance().get_stream();
_cublasHandle = Context::Instance().get_cublashandle();
T *buffer = _shared_mem_ptr;
@@ -215,7 +250,8 @@ void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_pt
4 * _batch_dim + max(3 * _batch_dim,
_batch_size * _head_num * _seq_len * _seq_len);
*/
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer);
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
grad_input_ptr, buffer);
}
template <typename T>
@@ -233,7 +269,8 @@ template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
@@ -241,15 +278,17 @@ template class MultiHeadAttention<__half>;
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim,
int num_heads, float attn_prob_dropout_ratio,
float hidden_dropout_ratio, bool pre_or_postLayerNorm,
int create_multihead_attention(int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_dim, int num_heads,
float attn_prob_dropout_ratio,
float hidden_dropout_ratio,
bool pre_or_postLayerNorm,
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Context::Instance().set_stream(stream);
auto layer = std::make_shared<MultiHeadAttention<T>>(
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio,
hidden_dropout_ratio, pre_or_postLayerNorm);
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
layer->SetPG(pg_);
@@ -261,15 +300,12 @@ int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_l
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Tensor &input,
const torch::Tensor &input_mask,
const torch::Tensor &in_proj_weight,
const torch::Tensor &in_proj_bias,
const torch::Tensor &out_proj_weight,
const torch::Tensor &out_proj_bias,
const torch::Tensor &norm_weight,
const torch::Tensor &norm_bias,
bool training_mode, bool prelayernorm) {
std::vector<torch::Tensor> multihead_attention_fw(
int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
bool training_mode, bool prelayernorm) {
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
@@ -280,7 +316,8 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
T *out_ptr = (T *)output.data_ptr();
std::shared_ptr<MultiHeadAttention<T>> layer =
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
std::static_pointer_cast<MultiHeadAttention<T>>(
s_multihead_attention[layer_id]);
layer->set_cur_batch_shape(input.size(0), input.size(1));
layer->SetTrainingMode(training_mode);
@@ -297,17 +334,13 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
const torch::Tensor &grad_dec_output,
const torch::Tensor &output,
const torch::Tensor &input,
const torch::Tensor &input_mask,
const torch::Tensor &in_proj_weight,
const torch::Tensor &in_proj_bias,
const torch::Tensor &out_proj_weight,
const torch::Tensor &out_proj_bias,
const torch::Tensor &norm_weight,
const torch::Tensor &norm_bias) {
std::vector<torch::Tensor> multihead_attention_bw(
int layer_id, const torch::Tensor &grad_dec_output,
const torch::Tensor &output, const torch::Tensor &input,
const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
const torch::Tensor &norm_bias) {
auto g_output = grad_dec_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
@@ -332,7 +365,8 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
T *grad_input_ptr = (T *)grad_input.data_ptr();
std::shared_ptr<MultiHeadAttention<T>> layer =
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
std::static_pointer_cast<MultiHeadAttention<T>>(
s_multihead_attention[layer_id]);
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
@@ -342,10 +376,12 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr);
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
grad_input_ptr);
return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
grad_out_proj_bias, grad_norm_weight, grad_norm_bias};
return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
grad_norm_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
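
The call sequence in attn_layer_fw reduces to the usual scaled dot-product attention; a toy single-head, single-batch version in plain C++ (LayerNorm, dropout, the layout transforms, and the tensor-parallel allreduce are omitted; sizes are illustrative):

#include <cmath>
#include <cstdio>

constexpr int S = 2, D = 2;  // toy seq_len and per-head dim

int main() {
  float q[S][D] = {{1, 0}, {0, 1}};
  float k[S][D] = {{1, 0}, {0, 1}};
  float v[S][D] = {{1, 2}, {3, 4}};
  float scores[S][S];
  float ctx[S][D] = {};
  float scale = 1.0f / std::sqrt((float)D);  // the 1/sqrt(hidden/heads) factor

  // attention scores, q * k^T (the _attn_scores strided-batch GEMM)
  for (int i = 0; i < S; i++)
    for (int j = 0; j < S; j++) {
      float s = 0.f;
      for (int d = 0; d < D; d++) s += q[i][d] * k[j][d];
      scores[i][j] = s * scale;
    }

  // row-wise softmax (an additive mask would be applied before the exp)
  for (int i = 0; i < S; i++) {
    float mx = scores[i][0], sum = 0.f;
    for (int j = 1; j < S; j++) mx = std::fmax(mx, scores[i][j]);
    for (int j = 0; j < S; j++) sum += (scores[i][j] = std::exp(scores[i][j] - mx));
    for (int j = 0; j < S; j++) scores[i][j] /= sum;
  }

  // attention context, score * v (the _attn_context GEMM)
  for (int i = 0; i < S; i++)
    for (int j = 0; j < S; j++)
      for (int d = 0; d < D; d++) ctx[i][d] += scores[i][j] * v[j][d];

  std::printf("ctx[0] = (%f, %f)\n", ctx[0][0], ctx[0][1]);
}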
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
@@ -19,21 +19,25 @@
template <typename T>
class MultiHeadAttention {
public:
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
int hidden_size, int num_heads, float attn_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm);
virtual ~MultiHeadAttention();
void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
const T *input_mask_ptr, T *grad_input_ptr);
void Backward(const T *grad_output_ptr, const T *input_ptr,
const T *output_ptr, const T *input_mask_ptr,
T *grad_input_ptr);
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
T *buffer);
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
const T *output_ptr, const T *grad_output_ptr,
T *grad_input_attn_layer_bwptr, T *buffer);
void set_cur_batch_shape(int batch_size, int seq_len) {
_batch_size = batch_size;
@@ -83,14 +87,17 @@ class MultiHeadAttention {
}
_qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
_soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_soft_out_ptr =
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_ctx_bufB_ptr =
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
// buffer size needed by attn bw
size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
_max_batch_tokens * _heads / pg_size * _max_seq_len);
size_t smem_size =
4 * _max_batch_tokens * _hidden_size / pg_size +
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
_max_batch_tokens * _heads / pg_size * _max_seq_len);
if (!_shared_mem_ptr) {
cuda_free(_shared_mem_ptr);
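
The shared-buffer sizing above is plain arithmetic; the same formula checked with sample values (all values hypothetical, counting elements of T):

#include <algorithm>
#include <cstdio>

int main() {
  long max_batch_tokens = 4096, hidden = 1024, heads = 16, max_seq_len = 256;
  long pg_size = 1;  // tensor-parallel world size
  long smem_size =
      4 * max_batch_tokens * hidden / pg_size +
      std::max(3 * max_batch_tokens * hidden / pg_size,
               max_batch_tokens * heads / pg_size * max_seq_len);
  std::printf("elements needed: %ld\n", smem_size);
}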
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu
@@ -2,12 +2,13 @@
* with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
@@ -15,17 +16,15 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
@@ -38,10 +37,10 @@ torch::Tensor fwd_cuda(
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
torch::Tensor softmax_results = torch::empty(
{batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
@@ -49,31 +48,23 @@ torch::Tensor fwd_cuda(
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
@@ -81,24 +72,18 @@ torch::Tensor bwd_cuda(
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, query_seq_len, key_seq_len, batches, attn_heads););
// backward pass is completely in-place
return output_grads;
}
}
}
}
} // namespace scaled_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn
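
Per row, the dispatched forward computes softmax(scale * input) with masked positions first forced to a large negative value before the exponentials; the kernel consumes a uint8 mask, and the -10000 substitution below is an assumption made for illustration:

#include <cmath>
#include <cstdio>

// Scalar sketch of scaled masked softmax on one row.
void scaled_masked_softmax_row(const float *x, const unsigned char *mask,
                               float *out, int n, float scale) {
  float mx = -INFINITY, sum = 0.f;
  for (int j = 0; j < n; j++) {
    out[j] = mask[j] ? -10000.f : scale * x[j];  // assumed mask convention
    mx = std::fmax(mx, out[j]);
  }
  for (int j = 0; j < n; j++) sum += (out[j] = std::exp(out[j] - mx));
  for (int j = 0; j < n; j++) out[j] /= sum;
}

int main() {
  float x[4] = {0.1f, 0.2f, 0.3f, 0.4f}, out[4];
  unsigned char mask[4] = {0, 0, 1, 1};  // mask out the last two keys
  scaled_masked_softmax_row(x, mask, out, 4, 1.0f);
  std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
}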
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp
@@ -3,57 +3,52 @@
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
torch::Tensor bwd(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results, float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
@@ -2,12 +2,13 @@
* with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
@@ -15,18 +16,15 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
@@ -36,50 +34,42 @@ torch::Tensor fwd_cuda(
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
seq_len, attn_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
// output grads is a 3d tensor with dimensions [attn_batches, seq_len,
// seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, seq_len, seq_len, attn_batches););
// backward pass is completely in-place
return output_grads;
}
}
}
}
} // namespace scaled_upper_triang_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn
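
The "completely in-place" backward above is the standard softmax Jacobian-vector product, grad_in = scale * y * (grad_out - sum_j grad_out_j * y_j), written back into the gradient buffer. A scalar sketch:

#include <cstdio>

// In-place scaled softmax backward for one row: grads is overwritten,
// y is the softmax output saved from the forward pass.
void softmax_bwd_inplace(float *grads, const float *y, int n, float scale) {
  float dot = 0.f;
  for (int j = 0; j < n; j++) dot += grads[j] * y[j];
  for (int j = 0; j < n; j++) grads[j] = scale * y[j] * (grads[j] - dot);
}

int main() {
  float g[3] = {1.f, 0.f, 0.f}, y[3] = {0.3f, 0.3f, 0.4f};
  softmax_bwd_inplace(g, y, 3, 1.0f);
  std::printf("%f %f %f\n", g[0], g[1], g[2]);
}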
colossalai/kernel/cuda_native/layer_norm.py
@@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = colossal_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@@ -72,8 +72,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
def forward(self, input):
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
self.normalized_shape, self.eps)
return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
def __repr__(self):
return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
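
forward_affine returns the triple (output, mean, invvar) that the backward pass reuses; a scalar restatement for one normalized row (eps and sizes here are illustrative):

#include <cmath>
#include <cstdio>

// One-row affine layer norm returning the same triple the fused op saves:
// normalized output, mean, and inverse standard deviation.
void layer_norm_row(const float *x, const float *w, const float *b, int n,
                    float eps, float *out, float *mean, float *invvar) {
  float mu = 0.f, var = 0.f;
  for (int i = 0; i < n; i++) mu += x[i];
  mu /= n;
  for (int i = 0; i < n; i++) var += (x[i] - mu) * (x[i] - mu);
  var /= n;
  float rstd = 1.f / std::sqrt(var + eps);
  for (int i = 0; i < n; i++) out[i] = (x[i] - mu) * rstd * w[i] + b[i];
  *mean = mu;
  *invvar = rstd;
}

int main() {
  float x[4] = {1, 2, 3, 4}, w[4] = {1, 1, 1, 1}, b[4] = {0, 0, 0, 0};
  float out[4], mean, invvar;
  layer_norm_row(x, w, b, 4, 1e-5f, out, &mean, &invvar);
  std::printf("mean = %f, invvar = %f, out[0] = %f\n", mean, invvar, out[0]);
}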
colossalai/kernel/cuda_native/scaled_softmax.py
@@ -28,9 +28,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
scale_t = torch.tensor([scale])
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
inputs, scale_t[0]
)
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@@ -43,9 +41,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
output_grads, softmax_results, scale_t[0]
)
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
@@ -81,9 +77,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_masked_softmax.backward(
output_grads, softmax_results, scale_t[0]
)
input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
@@ -114,9 +108,8 @@ class FusedScaleMaskSoftmax(nn.Module):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (self.input_in_fp16
and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -124,9 +117,7 @@ class FusedScaleMaskSoftmax(nn.Module):
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
@@ -140,14 +131,13 @@ class FusedScaleMaskSoftmax(nn.Module):
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (self.scaled_masked_softmax_fusion    # user wants to fuse
        and self.input_in_float16        # input must be fp16 or bf16
        and mask is not None             # mask tensor must not be None
        and 16 < sk <= 2048              # sk must be within 16 ~ 2048
        and sq % 4 == 0                  # sq must be divisible by 4
        and attn_batches % 4 == 0        # np * b must be divisible by 4
   ):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
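Read together, the guard above is a pure predicate on the problem size; a standalone sketch of the same checks (the function name and argument list are illustrative):

    def fused_kernel_ok(fusion_enabled, in_float16, mask, b, np, sq, sk):
        attn_batches = b * np
        return (fusion_enabled              # user asked for the fused kernel
                and in_float16              # kernel only handles fp16/bf16
                and mask is not None        # a mask tensor must be provided
                and 16 < sk <= 2048         # key length within kernel limits
                and sq % 4 == 0             # query length divisible by 4
                and attn_batches % 4 == 0)  # b * np divisible by 4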
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
@@ -9,10 +8,12 @@ import torch
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
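As the constants above indicate, bias_gelu uses the tanh approximation of GELU rather than the exact erf form; a quick numerical comparison (illustrative only):

    import torch

    x = torch.linspace(-3, 3, 101)
    exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
    approx = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    print((exact - approx).abs().max())  # small, on the order of 1e-3 or below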
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
@@ -23,9 +24,11 @@ def bias_gelu_back(g, bias, y):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * g
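The hand-derived gradient above can be sanity-checked against autograd; a minimal sketch using the two functions defined in this file:

    import torch

    bias = torch.randn(8, dtype=torch.double)
    y = torch.randn(8, dtype=torch.double, requires_grad=True)
    g = torch.ones(8, dtype=torch.double)
    auto_grad, = torch.autograd.grad(bias_gelu(bias, y), y, g)
    manual_grad = bias_gelu_back(g, bias, y)
    print(torch.allclose(auto_grad, manual_grad))  # expected: True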
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
@@ -38,4 +41,5 @@ class GeLUFunction(torch.autograd.Function):
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
@@ -182,7 +182,7 @@ class Linear2D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
@@ -337,16 +337,16 @@ class LayerNorm2D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
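The block above computes the layer-norm statistics with two row-group all-reduces, relying on the identity Var[x] = E[x^2] - (E[x])^2 so that each partial sum can be reduced independently. A single-device sketch of the same identity (shapes illustrative):

    import torch

    x = torch.randn(4, 16, 1024)                            # [b, s, hidden]
    E_x = x.sum(dim=-1, keepdim=True) / x.size(-1)          # E[x]
    E_x2 = (x * x).sum(dim=-1, keepdim=True) / x.size(-1)   # E[x^2]
    var = E_x2 - E_x * E_x                                  # Var[x]
    assert torch.allclose(var, x.var(dim=-1, unbiased=False, keepdim=True), atol=1e-5)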
@@ -569,7 +569,7 @@ class PatchEmbedding2D(ParallelLayer):
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)
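For reference, the flatten/transpose pattern above turns the conv2d patch grid [B, C, H, W] into the token layout [B, N, C]; a minimal shape check (sizes illustrative):

    import torch

    x = torch.randn(2, 768, 14, 14)        # conv output: [B, C, H, W]
    tokens = x.flatten(2).transpose(1, 2)  # -> [B, H*W, C] == [B, N, C]
    assert tokens.shape == (2, 196, 768)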
@@ -1012,7 +1012,7 @@ class Classifier2D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes,)
return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
@@ -1186,7 +1186,7 @@ class VocabParallelClassifier2D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
out_shape = x.shape[:-1] + (self.output_size_per_partition,)
output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
@@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_AB_2p5D.apply(
x,
@@ -254,7 +254,7 @@ class LayerNorm2p5D(ParallelLayer):
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
@@ -357,16 +357,16 @@ class LayerNorm2p5D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
@@ -589,7 +589,7 @@ class PatchEmbedding2p5D(ParallelLayer):
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
@@ -1038,7 +1038,7 @@ class Classifier2p5D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes,)
return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
@@ -1172,7 +1172,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
output = Matmul_ABT_2p5D.apply(
x,
@@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
if bias:
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.variance_epsilon = eps
@@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@@ -13,7 +13,8 @@ from torch import Tensor, nn
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True, offload: bool = False):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
@@ -78,6 +79,7 @@ def get_tensor_parallel_mode():
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
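The hunk above shows only the iterable branch of parse; the helper follows the familiar torchvision _ntuple pattern, whose full form (reconstructed here as a hedged sketch, since the tail falls outside this diff) is:

    import collections.abc
    from itertools import repeat

    def _ntuple(n):
        def parse(x):
            if isinstance(x, collections.abc.Iterable):
                return x
            return tuple(repeat(x, n))  # scalar -> n-tuple, e.g. 7 -> (7, 7) for n=2
        return parse

    to_2tuple = _ntuple(2)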