Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
]
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax
......@@ -17,8 +17,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction,
const float weight_decay);
const int bias_correction, const float weight_decay,
const float div_scale);
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
......
......@@ -28,7 +28,7 @@ struct AdamFunctor {
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta1_correction,
const float beta2_correction, const float epsilon, const float lr,
adamMode_t mode, const float decay) {
adamMode_t mode, const float decay, const float div_scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
......@@ -79,6 +79,8 @@ struct AdamFunctor {
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (div_scale > 0) r_g[ii] /= div_scale;
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
......@@ -116,8 +118,8 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction,
const float weight_decay) {
const int bias_correction, const float weight_decay,
const float div_scale) {
using namespace at;
// Handle bias correction mode
......@@ -133,7 +135,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);)
lr, (adamMode_t)mode, weight_decay, div_scale);)
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -2,8 +2,14 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp>
#endif
#include <iostream>
#include "context.h"
......
......@@ -4,8 +4,15 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif
#include <string>
#include <type_traits>
......
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
import math
import os
import subprocess
import torch
def triton_cuda_check():
cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
cuda_version = cuda_version.split('release ')[1]
cuda_version = cuda_version.split(',')[0]
cuda_version = cuda_version.split('.')
if len(cuda_version) == 2 and \
(int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \
int(cuda_version[0]) > 11:
return True
return False
try:
import triton
import triton.language as tl
if triton_cuda_check():
HAS_TRITON = True
else:
print("triton requires cuda >= 11.4")
HAS_TRITON = False
except ImportError:
print('please install triton from https://github.com/openai/triton')
HAS_TRITON = False
try:
from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_func,
flash_attn_unpadded_kvpacked_func,
flash_attn_unpadded_qkvpacked_func,
)
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
try:
from xformers.ops.fmha import memory_efficient_attention
HAS_MEM_EFF_ATTN = True
except ImportError:
HAS_MEM_EFF_ATTN = False
print('please install xformers from https://github.com/facebookresearch/xformers')
if HAS_TRITON:
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out,
DO,
L,
NewDO,
Delta,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
M,
D,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
Z,
H,
N_CTX,
num_block,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _TritonFlashAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
tmp,
L,
m,
o,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
o,
do,
l,
do_scaled,
delta,
BLOCK_M=ctx.BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q,
k,
v,
ctx.sm_scale,
o,
do_scaled,
dq,
dk,
dv,
l,
m,
delta,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK,
BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
)
return dq, dk, dv, None
def triton_flash_attention(q, k, v, sm_scale):
"""
Arguments:
q: (batch, nheads, seq, headdim)
k: (batch, nheads, seq, headdim)
v: (batch, nheads, seq, headdim)
sm_scale: float. The scaling of QK^T before applying softmax.
Return:
out: (batch, nheads, seq, headdim)
"""
if HAS_TRITON:
return _TritonFlashAttention.apply(q, k, v, sm_scale)
else:
raise RuntimeError("Triton kernel requires CUDA 11.4+!")
if HAS_FLASH_ATTN:
from einops import rearrange
class MaskedFlashAttention(torch.nn.Module):
def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size
self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
attention_dropout=attention_dropout)
def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
if attention_mask.dtype is not torch.bool:
attention_mask = attention_mask.bool()
qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
context = rearrange(context, 'b s h d -> b s (h d)')
return context
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
"""
Arguments:
qkv: (batch * seqlen, 3, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
dropout_p: float.
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (total, nheads, headdim).
"""
max_s = seq_len
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(qkv,
cu_seqlens,
max_s,
dropout_p,
softmax_scale=sm_scale,
causal=causal)
return out
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
"""
Arguments:
q: (batch * q_seqlen, nheads, headdim)
kv: (batch * kv_seqlen, 2, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
dropout_p: float.
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
step=kv_seqlen,
dtype=torch.int32,
device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
sm_scale, causal)
return out
def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
"""
Arguments:
q: (batch * q_seqlen, nheads, headdim)
k: (batch * kv_seqlen, nheads, headdim)
v: (batch * kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
step=kv_seqlen,
dtype=torch.int32,
device=k.device)
return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
causal)
if HAS_MEM_EFF_ATTN:
from einops import rearrange
from xformers.ops.fmha import LowerTriangularMask
class MemoryEfficientAttention(torch.nn.Module):
def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
super().__init__()
attention_head_size = hidden_size // num_attention_heads
self.scale = 1 / attention_head_size**0.5
self.dropout = attention_dropout
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
context = rearrange(context, 'b s h d -> b s (h d)')
return context
......@@ -3,14 +3,11 @@
with some changes. """
import numbers
import torch
from torch.nn.parameter import Parameter
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn import init
from torch.cuda.amp import custom_fwd, custom_bwd
import importlib
global colossal_layer_norm_cuda
colossal_layer_norm_cuda = None
from torch.nn.parameter import Parameter
class FusedLayerNormAffineFunction(torch.autograd.Function):
......@@ -18,14 +15,18 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input, weight, bias, normalized_shape, eps):
try:
from colossalai._C import layer_norm
except ImportError:
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
layer_norm = LayerNormBuilder().load()
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
ctx.eps)
output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
......@@ -33,11 +34,16 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
try:
from colossalai._C import layer_norm
except ImportError:
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
layer_norm = LayerNormBuilder().load()
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= colossal_layer_norm_cuda.backward_affine(
= layer_norm.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
......@@ -50,13 +56,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
super(MixedFusedLayerNorm, self).__init__()
global colossal_layer_norm_cuda
if colossal_layer_norm_cuda is None:
try:
colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
except ImportError:
raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
......
import math
import importlib
from dataclasses import dataclass
import torch
......@@ -50,8 +49,8 @@ class Config:
class MultiHeadAttention1DFunc(Function):
@staticmethod
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias, config):
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight,
norm_bias, config):
cuda_module = colossal_multihead_attention
forward_func = (cuda_module.multihead_attention_fw_fp16
if config.fp16 else cuda_module.multihead_attention_fw_fp32)
......@@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function):
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
out_proj_weight, out_proj_bias, norm_weight, norm_bias,
config.training, config.norm_first)
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first)
if config.is_grad_enabled and config.training:
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
out_proj_weight, out_proj_bias, norm_weight, norm_bias)
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias)
ctx.config = config
return output
......@@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function):
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias,
grad_norm_weight, grad_norm_bias, None)
class MultiHeadAttention(nn.Module):
......@@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module):
layer_id = 0
def __init__(self,
hidden_size,
nhead,
batch_size,
max_seq_len,
dropout=0.0,
norm_first=False,
fp16=True,
pg=None):
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
super(MultiHeadAttention, self).__init__()
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
dropout, norm_first, fp16)
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first,
fp16)
check_config(self.config)
self.pg = pg
self.pg_size = 1
......@@ -145,10 +135,9 @@ class MultiHeadAttention(nn.Module):
# Load cuda modules if needed
global colossal_multihead_attention
if colossal_multihead_attention is None:
try:
colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
except ImportError:
raise RuntimeError('MultiHeadAttention requires cuda extensions')
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load()
colossal_multihead_attention = multihead_attention
# create the layer in cuda kernels.
cuda_module = colossal_multihead_attention
......@@ -215,14 +204,13 @@ class MultiHeadAttention(nn.Module):
with torch.no_grad():
self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[
:, int(hs * rank_in_pg / self.pg_size):
int(hs * (rank_in_pg + 1) / self.pg_size),
:])
attn_qkvw_global.view(3, hs, hs)[:,
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size), :])
self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[
:, int(hs * rank_in_pg / self.pg_size):
int(hs * (rank_in_pg + 1) / self.pg_size)])
attn_qkvb_global.view(3, hs)[:,
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])
attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0)
......@@ -230,9 +218,9 @@ class MultiHeadAttention(nn.Module):
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
attn_ow_global = attn_ow_global.cpu()
with torch.no_grad():
self.out_proj_weight.copy_(attn_ow_global[
:, int(hs * rank_in_pg / self.pg_size):
int(hs * (rank_in_pg + 1) / self.pg_size)])
self.out_proj_weight.copy_(attn_ow_global[:,
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])
else:
attn_qkvw = self.in_proj_weight.view(-1, hs)
......@@ -243,10 +231,7 @@ class MultiHeadAttention(nn.Module):
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
def state_dict(self, destination=None, prefix="", keep_vars=False):
destination = torch.nn.Module.state_dict(self,
destination=destination,
prefix=prefix,
keep_vars=keep_vars)
destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
return destination
def forward(self, hidden_states, encoder_padding_mask):
......@@ -257,8 +242,7 @@ class MultiHeadAttention(nn.Module):
bs, sl, dim = hidden_states.size()
if bs * sl > self.config.max_batch_tokens:
raise ValueError(
f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
if sl > self.config.max_seq_len:
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
if len(encoder_padding_mask.size()) == 1:
......@@ -266,9 +250,8 @@ class MultiHeadAttention(nn.Module):
else:
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
self.in_proj_weight, self.in_proj_bias,
self.out_proj_weight, self.out_proj_bias,
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight,
self.in_proj_bias, self.out_proj_weight, self.out_proj_bias,
self.norm_weight, self.norm_bias, self.config)
return output.to(self.precision)
"""This code from NVIDIA Megatron
with some changes. """
import enum
import torch
import torch.nn as nn
import enum
class AttnMaskType(enum.Enum):
......@@ -22,26 +23,20 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, scale):
try:
import colossal_scaled_upper_triang_masked_softmax
except ImportError:
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
from colossalai.kernel import scaled_upper_triang_masked_softmax
scale_t = torch.tensor([scale])
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
try:
import colossal_scaled_upper_triang_masked_softmax
except ImportError:
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
from colossalai.kernel import scaled_upper_triang_masked_softmax
softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
......@@ -58,26 +53,28 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
try:
import colossal_scaled_masked_softmax
from colossalai._C import scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
scale_t = torch.tensor([scale])
softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
try:
import colossal_scaled_masked_softmax
from colossalai._C import scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
......@@ -184,8 +181,8 @@ class FusedScaleMaskSoftmax(nn.Module):
@staticmethod
def get_batch_per_block(sq, sk, b, np):
try:
import colossal_scaled_masked_softmax
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
import torch
from colossalai.nn.layer.colossalai_layer import Embedding, Linear
from colossalai.utils import get_current_device
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
JIT_OPTIONS_SET = False
......@@ -30,3 +36,44 @@ def set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_gpu(True)
JIT_OPTIONS_SET = True
def warmup_jit_fusion(batch_size: int,
hidden_size: int,
seq_length: int = 512,
vocab_size: int = 32768,
dtype: torch.dtype = torch.float32):
""" Compilie JIT functions before the main training steps """
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
x = embed(x)
y, y_bias = linear_1(x)
z, z_bias = linear_2(y)
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
for _ in range(10):
bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
bias_gelu_impl(input_, bias)
# Warmup fused bias+dropout+add
dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
for _ in range(10):
input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
input_.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
bias_dropout_add_fused_train(input_, bias, residual, dropout_rate)
torch.cuda.empty_cache()
../../op_builder
\ No newline at end of file
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import colossalai
import inspect
import logging
from pathlib import Path
from typing import Union, List
import inspect
from typing import List, Union
import colossalai
from colossalai.context.parallel_mode import ParallelMode
try:
from rich.logging import RichHandler
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO,
format=_FORMAT,
handlers=[RichHandler(show_path=False, markup=True, rich_tracebacks=True)])
except ImportError:
_FORMAT = 'colossalai - %(name)s - %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
class DistributedLogger:
"""This is a distributed event logger class essentially based on :class:`logging`.
......@@ -55,8 +45,23 @@ class DistributedLogger:
raise Exception(
'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger')
else:
handler = None
formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s')
try:
from rich.logging import RichHandler
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
handler.setFormatter(formatter)
except ImportError:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self._name = name
self._logger = logging.getLogger(name)
self._logger.setLevel(logging.INFO)
if handler is not None:
self._logger.addHandler(handler)
self._logger.propagate = False
DistributedLogger.__instances[name] = self
@staticmethod
......@@ -119,7 +124,7 @@ class DistributedLogger:
# add file handler
file_handler = logging.FileHandler(path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s')
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
......
from ._ops import *
from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .optimizer import *
from ._ops import *
from .linear import colo_linear
from .addmm import colo_addmm
from .batch_norm import colo_batch_norm
from .element_wise import *
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy
from .embedding import colo_embedding
from .addmm import colo_addmm
from .embedding_bag import colo_embedding_bag
from .layernorm import colo_layernorm
from .linear import colo_linear
from .loss import colo_cross_entropy
from .view import colo_view
......@@ -55,7 +55,7 @@ def colo_addmm(input_tensor: GeneralTensor,
mat2: ColoTensor,
beta: Number = 1,
alpha: Number = 1,
*args) -> ColoTensor:
**kargs) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
......@@ -70,7 +70,7 @@ def colo_addmm(input_tensor: GeneralTensor,
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha),
tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs),
spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():
......
from typing import Optional
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(F.batch_norm)
def colo_batch_norm(
input: GeneralTensor,
running_mean: Optional[GeneralTensor],
running_var: Optional[GeneralTensor],
weight: Optional[GeneralTensor] = None,
bias: Optional[GeneralTensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
):
assert isinstance(weight, ColoTensor)
running_mean = running_mean.detach()
running_var = running_var.detach()
input = convert_to_colo_tensor(input, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
input = input.redistribute(ReplicaSpec())
bias = bias.redistribute(ReplicaSpec())
output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
return output
import torch
import torch.nn.functional as F
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, ColoTensorSpec
from ._utils import GeneralTensor
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
def register_elementwise_op(op):
......@@ -15,8 +17,13 @@ def register_elementwise_op(op):
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
if 'inplace' in kwargs:
# TODO(jiaruifang) inplace will cause bugs
input_tensor = input_tensor.clone()
return op(input_tensor, *args, **kwargs)
else:
output = op(input_tensor, *args, **kwargs)
# return output
if isinstance(input_tensor, ColoTensor):
if isinstance(output, str):
return output
......@@ -27,6 +34,16 @@ def register_elementwise_op(op):
dist_attr=input_tensor.dist_spec))
# @colo_op_impl(torch.relu_)
# def elementwise_op(input_tensor):
# torch.relu_(input_tensor.data)
# return input_tensor
# @colo_op_impl(Tensor.add_)
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
# input_tensor = input_tensor.data.add_(*args, **kwargs)
# return input_tensor
# Tensor op
register_elementwise_op(Tensor.abs)
register_elementwise_op(Tensor.absolute)
......
import torch.nn.functional as F
from copy import deepcopy
from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input
def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
......@@ -162,10 +164,8 @@ def _has_sharding_spec(tensor):
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor,
weight: GeneralTensor,
bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
if _has_sharding_spec(weight):
return _new_colo_linear_imp(input_tensor, weight, bias)
return _new_colo_linear_imp(input, weight, bias)
else:
return colo_linear_imp(input_tensor, weight, bias)
return colo_linear_imp(input, weight, bias)
from .utils import register_colo_graph
from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment