Commit 98957dd7 authored by luopl's avatar luopl
Browse files

init

parents
Pipeline #1625 canceled with stages
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_bwd_kernel.cuh"
template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_bwd_kernel.cuh"
template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_bwd_kernel.cuh"
template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_bwd_kernel.cuh"
template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
This diff is collapsed.
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
#include <c10/util/complex.h> // For scalar_value_type
#ifndef USE_ROCM
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
#define MAX_DSTATE 256
using complex_t = c10::complex<float>;
inline __device__ float2 operator+(const float2 & a, const float2 & b){
return {a.x + b.x, a.y + b.y};
}
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
return {a.x + b.x, a.y + b.y, a.z + b.z};
}
inline __device__ float4 operator+(const float4 & a, const float4 & b){
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename scalar_t, int N>
struct Converter{
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
}
};
template<int N>
struct Converter<at::Half, N>{
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
static_assert(N % 2 == 0);
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
#pragma unroll
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
}
};
#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
static_assert(N % 2 == 0);
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
#pragma unroll
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
float t = exp2f(z.real_);
float c, s;
sincosf(z.imag_, &s, &c);
return complex_t(c * t, s * t);
}
__device__ __forceinline__ complex_t cexpf(complex_t z) {
float t = expf(z.real_);
float c, s;
sincosf(z.imag_, &s, &c);
return complex_t(c * t, s * t);
}
template<typename scalar_t> struct SSMScanOp;
template<>
struct SSMScanOp<float> {
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
}
};
template<>
struct SSMScanOp<complex_t> {
__device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
complex_t a0 = complex_t(ab0.x, ab0.y);
complex_t b0 = complex_t(ab0.z, ab0.w);
complex_t a1 = complex_t(ab1.x, ab1.y);
complex_t b1 = complex_t(ab1.z, ab1.w);
complex_t out_a = a1 * a0;
complex_t out_b = a1 * b0 + b1;
return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
}
};
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
scan_t running_prefix;
// Constructor
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ scan_t operator()(scan_t block_aggregate) {
scan_t old_prefix = running_prefix;
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
return old_prefix;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
reinterpret_cast<vec_t*>(u),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
#ifdef USE_ROCM
, Ktraits::kNThreads * Ktraits::kNLoads
#endif
);
} else {
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
}
}
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
int seqlen) {
constexpr int kNItems = Ktraits::kNItems;
if constexpr (!Ktraits::kIsComplex) {
typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
reinterpret_cast<vec_t*>(Bvar),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
);
} else {
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
}
// #pragma unroll
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
} else {
typename Ktraits::input_t B_vals_load[kNItems * 2];
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
reinterpret_cast<vec_t*>(Bvar),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
);
} else {
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
}
#pragma unroll
for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
}
}
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
const float (&out_vals)[Ktraits::kNItems],
typename Ktraits::BlockStoreT::TempStorage &smem_store,
int seqlen) {
typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
reinterpret_cast<vec_t*>(out),
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
);
} else {
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
}
}
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_fwd_kernel.cuh"
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_fwd_kernel.cuh"
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// Split into multiple files to compile in paralell
#include "selective_scan_fwd_kernel.cuh"
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
This diff is collapsed.
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
/******************************************************************************
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#ifndef USE_ROCM
#include <cub/config.cuh>
#include <cuda/std/type_traits>
#else
#include <hipcub/hipcub.hpp>
// Map ::cuda::std to the standard std namespace
namespace cuda {
namespace std = ::std;
}
#endif
namespace detail
{
#if defined(_NVHPC_CUDA)
template <typename T, typename U>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
// NVBug 3384810
new (ptr) T(::cuda::std::forward<U>(val));
}
#else
template <typename T,
typename U,
typename ::cuda::std::enable_if<
::cuda::std::is_trivially_copyable<T>::value,
int
>::type = 0>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
*ptr = ::cuda::std::forward<U>(val);
}
template <typename T,
typename U,
typename ::cuda::std::enable_if<
!::cuda::std::is_trivially_copyable<T>::value,
int
>::type = 0>
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
{
new (ptr) T(::cuda::std::forward<U>(val));
}
#endif
} // namespace detail
import torch
import transformers
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate
@register_model("mamba")
class MambaEvalWrapper(HFLM):
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
dtype=torch.float16):
LM.__init__(self)
self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.vocab_size = self.tokenizer.vocab_size
self._batch_size = int(batch_size) if batch_size is not None else 64
self._max_length = max_length
self._device = torch.device(device)
@property
def batch_size(self):
return self._batch_size
def _model_generate(self, context, max_length, stop, **generation_kwargs):
raise NotImplementedError()
if __name__ == "__main__":
cli_evaluate()
icon.png

53.8 KB

__version__ = "2.2.2"
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from typing import Optional
import torch
from torch import Tensor
from torch.distributed import ProcessGroup
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
return input_, handle
class AllGatherFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_gather_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
return grad_input, None
# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
class ReduceScatterFunc(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = reduce_scatter_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
return grad_input, None
# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply
class AllReduceFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_reduce_raw(input_, process_group)
return output
@staticmethod
def backward(ctx, grad_output: Tensor):
return grad_output, None
# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
pamams_shared = {
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
}
for _, p in sorted(pamams_shared.items()):
with torch.no_grad():
# Broadcast needs src to be global rank, not group rank
torch.distributed.broadcast(
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
)
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
params_seqparallel = {
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
}
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
if grads:
with torch.no_grad():
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
"""Get the dim for the local rank derived from splitting dim on world_size processes.
The split may not be even across the world_size processes.
"""
multiple = dim // multiple_of
div = multiple // world_size
mod = multiple % world_size
local_multiple = div + int(local_rank < mod)
return local_multiple * multiple_of
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from einops import rearrange
from mamba_ssm.distributed.distributed_utils import (
all_gather_raw,
all_reduce,
all_reduce_raw,
reduce_scatter,
reduce_scatter_raw,
)
class ParallelLinearFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
"""
ctx.compute_weight_gradient = weight.requires_grad
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
if torch.is_autocast_enabled():
x = x.to(dtype=torch.get_autocast_gpu_dtype())
x = x.contiguous()
if process_group is not None and sequence_parallel:
# We want to kick off the all_gather early, before weight dtype conversion
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
if torch.is_autocast_enabled():
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
weight = weight.contiguous()
if process_group is not None and sequence_parallel:
handle_x.wait()
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
output = F.linear(total_x, weight, bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
else:
ctx.save_for_backward(weight)
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
else:
total_x = x
else:
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[0]:
grad_input = F.linear(grad_output, weight.t())
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
else:
grad_input = None
if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight = torch.einsum(
"bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
)
else:
grad_weight = None
grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return grad_input, grad_weight, grad_bias, None, None
def parallel_linear_func(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True,
):
return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
class ColumnParallelLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % multiple_of:
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
multiple = out_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
super().__init__(
in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
def forward(self, x):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return parallel_linear_func(
x,
self.weight,
self.bias,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
)
class RowParallelLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % multiple_of:
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
multiple = in_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
# Only rank 0 will have bias
super().__init__(
local_multiple * multiple_of,
out_features,
bias=bias and rank == 0,
device=device,
dtype=dtype,
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
a reduce_scatter of the result.
"""
out = parallel_linear_func(x, self.weight, self.bias)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return reduce_fn(out, self.process_group)
class VocabParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if num_embeddings % world_size != 0:
raise ValueError(
f"num_embeddings ({num_embeddings}) must be divisible by "
f"world_size ({world_size})"
)
if world_size > 1 and padding_idx is not None:
raise RuntimeError("ParallelEmbedding does not support padding_idx")
else:
world_size = 1
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
def forward(self, input: Tensor) -> Tensor:
if self.process_group is None:
return super().forward(input)
else:
rank = torch.distributed.get_rank(self.process_group)
vocab_size = self.num_embeddings
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
input = input - vocab_start_index
input[input_ids_mask] = 0
embeddings = super().forward(input)
embeddings[input_ids_mask] = 0.0
return embeddings
class ColumnParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if embedding_dim % world_size != 0:
raise ValueError(
f"embedding_dim ({embedding_dim}) must be divisible by "
f"world_size ({world_size})"
)
else:
world_size = 1
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
class ParallelEmbeddings(nn.Module):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
process_group,
padding_idx=None,
sequence_parallel=True,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.word_embeddings = VocabParallelEmbedding(
vocab_size,
embed_dim,
padding_idx=padding_idx,
process_group=process_group,
**factory_kwargs,
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = ColumnParallelEmbedding(
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
)
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
world_size = torch.distributed.get_world_size(self.process_group)
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
if world_size <= 1:
embeddings = embeddings + position_embeddings
else:
partition_dim = self.position_embeddings.embedding_dim
rank = torch.distributed.get_rank(self.process_group)
embeddings[
..., rank * partition_dim : (rank + 1) * partition_dim
] += position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, "b s d -> (b s) d")
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
from dataclasses import dataclass, field
@dataclass
class MambaConfig:
d_model: int = 2560
d_intermediate: int = 0
n_layer: int = 64
vocab_size: int = 50277
ssm_cfg: dict = field(default_factory=dict)
attn_layer_idx: list = field(default_factory=list)
attn_cfg: dict = field(default_factory=dict)
rms_norm: bool = True
residual_in_fp32: bool = True
fused_add_norm: bool = True
pad_vocab_size_multiple: int = 8
tie_embeddings: bool = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment