Commit 98957dd7 authored by luopl: init
# Copyright (c) 2023, Albert Gu, Tri Dao.
import math
from functools import partial
import json
import os
import copy
from collections import namedtuple
import torch
import torch.nn as nn
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.modules.block import Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
def create_block(
d_model,
d_intermediate,
ssm_cfg=None,
attn_layer_idx=None,
attn_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
device=None,
dtype=None,
):
if ssm_cfg is None:
ssm_cfg = {}
if attn_layer_idx is None:
attn_layer_idx = []
if attn_cfg is None:
attn_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
if layer_idx not in attn_layer_idx:
# Create a copy of the config to modify
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
if ssm_layer not in ["Mamba1", "Mamba2"]:
raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
mixer_cls = partial(
Mamba2 if ssm_layer == "Mamba2" else Mamba,
layer_idx=layer_idx,
**ssm_cfg,
**factory_kwargs
)
else:
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
if d_intermediate == 0:
mlp_cls = nn.Identity
else:
mlp_cls = partial(
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
)
block = Block(
d_model,
mixer_cls,
mlp_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
block.layer_idx = layer_idx
return block
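# Illustrative sketch (not part of the original commit): how create_block is
# typically called when assembling a layer stack. The helper name and the
# hyperparameter values below are assumptions for demonstration; the Mamba kernels
# still require a CUDA device at forward time.
def _example_create_block():
    """Build one Mamba2 block with illustrative settings."""
    return create_block(
        d_model=768,
        d_intermediate=0,              # 0 -> nn.Identity instead of the gated MLP
        ssm_cfg={"layer": "Mamba2"},   # "Mamba1" is the default mixer
        attn_layer_idx=[],             # no attention layers in this sketch
        rms_norm=False,                # plain nn.LayerNorm; RMSNorm needs Triton
        fused_add_norm=False,
        layer_idx=0,
        device="cuda",
        dtype=torch.bfloat16,
    )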
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
):
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
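# Worked example of the rescaling above (values assumed for illustration): with
# n_layer = 24 and an MLP in every block (n_residuals_per_layer = 2), each
# out_proj.weight / fc2.weight is re-drawn with kaiming_uniform_ and then divided
# by sqrt(2 * 24) ≈ 6.93.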
class MixerModel(nn.Module):
def __init__(
self,
d_model: int,
n_layer: int,
d_intermediate: int,
vocab_size: int,
ssm_cfg=None,
attn_layer_idx=None,
attn_cfg=None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
initializer_cfg=None,
fused_add_norm=False,
residual_in_fp32=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
# We change the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
self.fused_add_norm = fused_add_norm
if self.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
self.layers = nn.ModuleList(
[
create_block(
d_model,
d_intermediate=d_intermediate,
ssm_cfg=ssm_cfg,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
norm_epsilon=norm_epsilon,
rms_norm=rms_norm,
residual_in_fp32=residual_in_fp32,
fused_add_norm=fused_add_norm,
layer_idx=i,
**factory_kwargs,
)
for i in range(n_layer)
]
)
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
d_model, eps=norm_epsilon, **factory_kwargs
)
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
)
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
for i, layer in enumerate(self.layers)
}
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = layer_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm_f, RMSNorm)
)
return hidden_states
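# Note (illustration, not from the commit): in the fused branch above,
# layer_norm_fn(x, w, b, residual=r, prenorm=False, ...) computes the same result as
# the unfused branch, i.e. norm_f(x + r), with the add and the (layer|rms)norm done
# in one Triton kernel and the residual accumulation optionally kept in fp32.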
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
config: MambaConfig,
initializer_cfg=None,
device=None,
dtype=None,
) -> None:
self.config = config
d_model = config.d_model
n_layer = config.n_layer
d_intermediate = config.d_intermediate
vocab_size = config.vocab_size
ssm_cfg = config.ssm_cfg
attn_layer_idx = config.attn_layer_idx
attn_cfg = config.attn_cfg
rms_norm = config.rms_norm
residual_in_fp32 = config.residual_in_fp32
fused_add_norm = config.fused_add_norm
pad_vocab_size_multiple = config.pad_vocab_size_multiple
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
d_intermediate=d_intermediate,
vocab_size=vocab_size,
ssm_cfg=ssm_cfg,
attn_layer_idx=attn_layer_idx,
attn_cfg=attn_cfg,
rms_norm=rms_norm,
initializer_cfg=initializer_cfg,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
if self.config.tie_embeddings:
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
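    # Note (illustration): when only the next-token distribution is needed, e.g.
    # while processing a prompt before decoding, callers typically pass
    # num_last_tokens=1 so the lm_head matmul only runs on the final position:
    #   logits = model(input_ids, num_last_tokens=1).logits   # (batch, 1, vocab_size)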
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config_data = load_config_hf(pretrained_model_name)
config = MambaConfig(**config_data)
model = cls(config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
return model
def save_pretrained(self, save_directory):
"""
Minimal implementation of save_pretrained for MambaLMHeadModel.
Save the model and its configuration file to a directory.
"""
# Ensure save_directory exists
os.makedirs(save_directory, exist_ok=True)
# Save the model's state_dict
model_path = os.path.join(save_directory, 'pytorch_model.bin')
torch.save(self.state_dict(), model_path)
# Save the configuration of the model
config_path = os.path.join(save_directory, 'config.json')
with open(config_path, 'w') as f:
json.dump(self.config.__dict__, f, indent=4)
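# Usage sketch (not part of the commit; the checkpoint name, local path, and CUDA
# device are assumptions for illustration):
def _example_load_and_save():
    """Round-trip: load a HF checkpoint, then save it in this minimal format."""
    model = MambaLMHeadModel.from_pretrained(
        "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
    )
    model.save_pretrained("./mamba-130m-local")
    return model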
# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional
import torch
from torch import nn, Tensor
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
class Block(nn.Module):
def __init__(
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
"""
Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA/MLP -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Add -> LN -> Mixer, returning both
the hidden_states (output of the mixer) and the residual.
This is purely for performance reasons, as we can fuse add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.norm = norm_cls(dim)
self.mixer = mixer_cls(dim)
if mlp_cls is not nn.Identity:
self.norm2 = norm_cls(dim)
self.mlp = mlp_cls(dim)
else:
self.mlp = None
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual))
"""
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
is_rms_norm=isinstance(self.norm, RMSNorm)
)
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
if self.mlp is not None:
if not self.fused_add_norm:
residual = hidden_states + residual
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm2.eps,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
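# Usage sketch (illustration only): the residual threading the docstring describes.
# `blocks` is assumed to be an iterable of Block instances and `final_norm` the
# model's closing LayerNorm/RMSNorm; residual starts as None for the first block.
def _example_run_blocks(blocks, final_norm, hidden_states):
    residual = None
    for block in blocks:
        hidden_states, residual = block(hidden_states, residual)
    residual = hidden_states if residual is None else hidden_states + residual
    return final_norm(residual)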
# Copyright (c) 2024, Tri Dao, Albert Gu.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
except ImportError:
causal_conv1d_varlen_states = None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
from huggingface_hub import PyTorchModelHubMixin
class Mamba2(nn.Module, PyTorchModelHubMixin):
def __init__(
self,
d_model,
d_state=128,
d_conv=4,
conv_init=None,
expand=2,
headdim=64,
d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
ngroups=1,
A_init_range=(1, 16),
D_has_hdim=False,
rmsnorm=True,
norm_before_gate=False,
dt_min=0.001,
dt_max=0.1,
dt_init_floor=1e-4,
dt_limit=(0.0, float("inf")),
bias=False,
conv_bias=True,
# Fused kernel and sharding options
chunk_size=256,
use_mem_eff_path=True,
layer_idx=None, # Absorb kwarg for general module
process_group=None,
sequence_parallel=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.conv_init = conv_init
self.expand = expand
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.world_size = 1 if process_group is None else process_group.size()
self.local_rank = 0 if process_group is None else process_group.rank()
self.d_inner = (self.expand * self.d_model) // self.world_size
assert self.d_inner * self.world_size == self.expand * self.d_model
self.headdim = headdim
self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
assert ngroups % self.world_size == 0
self.ngroups = ngroups // self.world_size
assert self.d_ssm % self.headdim == 0
self.nheads = self.d_ssm // self.headdim
self.D_has_hdim = D_has_hdim
self.rmsnorm = rmsnorm
self.norm_before_gate = norm_before_gate
self.dt_limit = dt_limit
self.activation = "silu"
self.chunk_size = chunk_size
self.use_mem_eff_path = use_mem_eff_path
self.layer_idx = layer_idx
# Order: [z, x, B, C, dt]
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
if self.process_group is None:
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
else:
self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs)
conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
**factory_kwargs,
)
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
self.act = nn.SiLU()
# Initialize log dt bias
dt = torch.exp(
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
self.D._no_weight_decay = True
if self.rmsnorm:
assert RMSNormGated is not None
self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
group_size=self.d_ssm // ngroups, **factory_kwargs)
if self.process_group is None:
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
else:
self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs)
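    # Worked example of the sizes above (d_model = 768 assumed, defaults otherwise):
    # d_inner = 2 * 768 = 1536 and nheads = 1536 // 64 = 24, so the fused input
    # projection emits d_in_proj = 2*1536 + 2*1*128 + 24 = 3352 channels, split later
    # as [z, x, B, C, dt], and the depthwise conv covers
    # conv_dim = 1536 + 2*128 = 1792 channels (x, B, C only).
    # The dt_bias init uses the softplus inverse: softplus(dt + log(-expm1(-dt))) = dt,
    # so softplus(dt_bias) alone reproduces the sampled dt values.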
def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
"""
u: (batch, seqlen, hidden_dim) if seqlen=None.
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
split u during sequence parallel, we split the batch * seqlen dimension
(in case batch is small).
Returns: same shape as u
"""
seqlen_og = seqlen
if seqlen is None:
batch, seqlen, dim = u.shape
else:
batch_seqlen, dim = u.shape
batch = batch_seqlen // seqlen
conv_state, ssm_state = None, None
if inference_params is not None:
inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(u, conv_state, ssm_state)
return out
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
if seqlen_og is not None:
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
# If the model is loaded in fp16, without the .float() here, A might be -inf
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
if self.use_mem_eff_path and inference_params is None:
out = mamba_split_conv1d_scan_combined(
zxbcdt,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.dt_bias,
A,
D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
chunk_size=self.chunk_size,
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=None if self.D_has_hdim else self.headdim,
ngroups=self.ngroups,
norm_before_gate=self.norm_before_gate,
**dt_limit_kwargs,
)
if seqlen_og is not None:
out = rearrange(out, "b l d -> (b l) d")
if self.process_group is not None:
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
out = reduce_fn(out, self.process_group)
else:
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
z0, x0, z, xBC, dt = torch.split(
zxbcdt,
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
dim=-1
)
if conv_state is not None:
if cu_seqlens is None:
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
xBC_t = rearrange(xBC, "b l d -> b d l")
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
else:
assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
assert batch == 1, "varlen inference only supports batch dimension 1"
conv_varlen_states = causal_conv1d_varlen_states(
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
)
conv_state.copy_(conv_varlen_states)
assert self.activation in ["silu", "swish"]
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
xBC = self.act(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)]
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
else:
xBC = causal_conv1d_fn(
xBC.transpose(1, 2),
rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
chunk_size=self.chunk_size,
D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
dt_bias=self.dt_bias,
dt_softplus=True,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
**dt_limit_kwargs,
return_final_states=ssm_state is not None,
return_varlen_states=cu_seqlens is not None and inference_params is not None,
)
if ssm_state is not None:
y, last_state, *rest = y
if cu_seqlens is None:
ssm_state.copy_(last_state)
else:
varlen_states = rest[0]
ssm_state.copy_(varlen_states)
y = rearrange(y, "b l h p -> b l (h p)")
if self.rmsnorm:
y = self.norm(y, z)
if d_mlp > 0:
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
if seqlen_og is not None:
y = rearrange(y, "b l d -> (b l) d")
out = self.out_proj(y)
return out
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
z0, x0, z, xBC, dt = torch.split(
zxbcdt,
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
dim=-1
)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = xBC
xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
xBC = xBC + self.conv1d.bias
xBC = self.act(xBC).to(dtype=dtype)
else:
xBC = causal_conv1d_update(
xBC,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
A = -torch.exp(self.A_log.float()) # (nheads,)
# SSM step
if selective_state_update is None:
assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
# Discretize A and B
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
dt = repeat(dt, "b h -> b h p", p=self.headdim)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
D = repeat(self.D, "h -> h p", p=self.headdim)
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
if not self.rmsnorm:
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
y = selective_state_update(
ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
dt_bias=dt_bias, dt_softplus=True
)
y = rearrange(y, "b h p -> b (h p)")
if self.rmsnorm:
y = self.norm(y, z)
if d_mlp > 0:
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
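    # Note (illustrative notation): the non-fused branch of step() above is the
    # discrete SSM recurrence, per batch b, head h, channel p and state index n:
    #   state[b,h,p,n] <- exp(dt[b,h] * A[h]) * state[b,h,p,n] + dt[b,h] * B[b,n] * x[b,h,p]
    #   y[b,h,p]       <- sum_n state[b,h,p,n] * C[b,n] + D[h] * x[b,h,p]
    # with dt = softplus(dt_raw + dt_bias), matching the einsums in that branch.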
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
).transpose(1, 2)
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
ssm_state = torch.zeros(
batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
)
return conv_state, ssm_state
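    # Note (shapes for the assumed d_model = 768 example above): conv_state is
    # (batch, conv_dim, d_conv) = (batch, 1792, 4) and ssm_state is
    # (batch, nheads, headdim, d_state) = (batch, 24, 64, 128).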
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
batch_shape = (batch_size,)
conv_state = torch.zeros(
batch_size,
self.d_conv,
self.conv1d.weight.shape[0],
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
).transpose(1, 2)
ssm_state = torch.zeros(
batch_size,
self.nheads,
self.headdim,
self.d_state,
device=self.in_proj.weight.device,
dtype=self.in_proj.weight.dtype,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
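# Usage sketch (not part of the commit; assumes a CUDA device with the Triton and
# causal-conv1d kernels available, and illustrative sizes):
def _example_mamba2_forward():
    layer = Mamba2(d_model=256, layer_idx=0, device="cuda", dtype=torch.bfloat16)
    u = torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)
    y = layer(u)  # (batch, seqlen, d_model) = (2, 128, 256)
    # Sequence-parallel style call: flatten (batch, seqlen) and pass seqlen explicitly.
    y_flat = layer(u.reshape(2 * 128, 256), seqlen=128)  # (batch * seqlen, d_model)
    return y, y_flat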
# Copyright (c) 2024, Tri Dao, Albert Gu.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn
except ImportError:
causal_conv1d_fn = None
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
RMSNormGated, LayerNorm = None, None
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
class Mamba2Simple(nn.Module):
def __init__(
self,
d_model,
d_state=64,
d_conv=4,
conv_init=None,
expand=2,
headdim=128,
ngroups=1,
A_init_range=(1, 16),
dt_min=0.001,
dt_max=0.1,
dt_init_floor=1e-4,
dt_limit=(0.0, float("inf")),
learnable_init_states=False,
activation="swish",
bias=False,
conv_bias=True,
# Fused kernel and sharding options
chunk_size=256,
use_mem_eff_path=True,
layer_idx=None, # Absorb kwarg for general module
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.conv_init = conv_init
self.expand = expand
self.d_inner = self.expand * self.d_model
self.headdim = headdim
self.ngroups = ngroups
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
self.dt_limit = dt_limit
self.learnable_init_states = learnable_init_states
self.activation = activation
self.chunk_size = chunk_size
self.use_mem_eff_path = use_mem_eff_path
self.layer_idx = layer_idx
# Order: [z, x, B, C, dt]
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
**factory_kwargs,
)
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
# self.conv1d.weight._no_weight_decay = True
if self.learnable_init_states:
self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
self.init_states._no_weight_decay = True
self.act = nn.SiLU()
# Initialize log dt bias
dt = torch.exp(
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
# A parameter
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
self.D._no_weight_decay = True
# Extra normalization layer right before output projection
assert RMSNormGated is not None
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
def forward(self, u, seq_idx=None):
"""
u: (B, L, D)
Returns: same shape as u
"""
batch, seqlen, dim = u.shape
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
if self.use_mem_eff_path:
# Fully fused path
out = mamba_split_conv1d_scan_combined(
zxbcdt,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.eps,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=self.headdim,
ngroups=self.ngroups,
norm_before_gate=False,
initial_states=initial_states,
**dt_limit_kwargs,
)
else:
z, xBC, dt = torch.split(
zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
)
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
assert self.activation in ["silu", "swish"]
# 1D Convolution
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
xBC = self.act(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
) # (B, L, self.d_inner + 2 * ngroups * d_state)
xBC = xBC[:, :seqlen, :]
else:
xBC = causal_conv1d_fn(
x=xBC.transpose(1, 2),
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
).transpose(1, 2)
# Split into 3 main branches: X, B, C
# These correspond to V, K, Q respectively in the SSM/attention duality
x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=seq_idx,
initial_states=initial_states,
**dt_limit_kwargs,
)
y = rearrange(y, "b l h p -> b l (h p)")
# Multiply "gate" branch and apply extra normalization layer
y = self.norm(y, z)
out = self.out_proj(y)
return out
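# Usage sketch (not part of the commit; CUDA device and Triton kernels assumed,
# sizes illustrative):
def _example_mamba2_simple():
    layer = Mamba2Simple(d_model=256, headdim=64, device="cuda", dtype=torch.bfloat16)
    u = torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)
    return layer(u)  # (2, 128, 256)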
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F
class GatedMLP(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation=F.silu,
bias=False,
multiple_of=128,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features if out_features is not None else in_features
hidden_features = (
hidden_features if hidden_features is not None else int(8 * in_features / 3)
)
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
def forward(self, x):
y = self.fc1(x)
y, gate = y.chunk(2, dim=-1)
y = y * self.activation(gate)
y = self.fc2(y)
return y
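# Worked example (in_features = 512 assumed, defaults otherwise): hidden_features
# defaults to int(8 * 512 / 3) = 1365, rounded up to the next multiple of 128 -> 1408,
# so fc1 is Linear(512, 2816) (value and gate fused) and fc2 is Linear(1408, 512).
def _example_gated_mlp():
    import torch  # this module only imports nn / F above
    mlp = GatedMLP(512)
    x = torch.randn(2, 16, 512)
    return mlp(x)  # (2, 16, 512); runs on CPU since GatedMLP is plain nn.Linear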
import triton
import triton.language as tl
from packaging import version
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
if TRITON3:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
return dt
else:
@triton.jit
def softplus(dt):
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
return dt
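# Note (illustration, not from the commit): the 20.0 cutoff mirrors the threshold in
# torch.nn.functional.softplus, which likewise returns x unchanged for large inputs;
# at x = 20, log(1 + exp(x)) and x already agree to within ~2e-9.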