Unverified Commit 76b3b20f authored by Kola, committed by GitHub

Update Mamba types and pass through use_cache attr to MambaModel (#29605)



* Update docstring for RMSNorm

* Update cache_params object to correct MambaCache type

* Update docstrings and type info

* Pass through use_cache

* ruff

* Reformat with 119 char limit per line (thanks Arthur)

* Pass use_cache through to the backbone specifically, rather than forwarding all keyword arguments

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tab

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
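
Reviewer note: a minimal sketch of how the typed cache and the forwarded `use_cache` flag surface at the call site. The checkpoint name is illustrative and not part of this change; the calls rely only on the signatures touched in this diff.

```python
from transformers import AutoTokenizer, MambaForCausalLM

# Illustrative checkpoint; any Mamba checkpoint with this architecture behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

inputs = tokenizer("Mamba is a state space model", return_tensors="pt")

# use_cache is now forwarded to the backbone, which builds a MambaCache.
outputs = model(input_ids=inputs.input_ids, use_cache=True)
print(type(outputs.cache_params).__name__)  # MambaCache, not a list of tensors

# The cache can then be fed back with only the next token, as the docstrings describe.
next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
outputs = model(input_ids=next_token, cache_params=outputs.cache_params, use_cache=True)
```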
parent 776c9d3a
@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -59,6 +59,24 @@ _CONFIG_FOR_DOC = "MambaConfig"
 MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = []  # See all Mamba models at https://huggingface.co/models?filter=mamba
 
 
+class MambaCache:
+    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
+        self.seqlen_offset = 0
+        self.dtype = dtype
+        intermediate_size = config.intermediate_size
+        ssm_state_size = config.state_size
+        conv_kernel_size = config.conv_kernel
+
+        self.conv_states = {
+            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+            for i in range(config.num_hidden_layers)
+        }
+        self.ssm_states = {
+            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+            for i in range(config.num_hidden_layers)
+        }
+
+
 class MambaMixer(nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
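
Review aside: a quick sketch of what the relocated `MambaCache` holds, using a tiny illustrative config; the shapes follow directly from the constructor above.

```python
from transformers import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaCache

# Tiny illustrative config: intermediate_size = expand * hidden_size = 32.
config = MambaConfig(num_hidden_layers=2, hidden_size=16, expand=2, state_size=4, conv_kernel=4)
cache = MambaCache(config, batch_size=1)

# One rolling convolution state and one SSM state per layer, keyed by layer index.
print(cache.conv_states[0].shape)  # torch.Size([1, 32, 4]) -> (batch, intermediate_size, conv_kernel)
print(cache.ssm_states[0].shape)   # torch.Size([1, 32, 4]) -> (batch, intermediate_size, state_size)
```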
@@ -112,7 +130,7 @@ class MambaMixer(nn.Module):
                 " https://github.com/Dao-AILab/causal-conv1d"
             )
 
-    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None):
+    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -202,7 +220,7 @@ class MambaMixer(nn.Module):
         return contextualized_states
 
     # fmt: off
-    def slow_forward(self, input_states, cache_params=None):
+    def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
         # 1. Gated MLP's linear projection
@@ -268,34 +286,16 @@ class MambaMixer(nn.Module):
         return contextualized_states
     # fmt: on
 
-    def forward(self, hidden_states, cache_params=None):
+    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params)
         return self.slow_forward(hidden_states, cache_params)
 
 
-class MambaCache:
-    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
-        self.seqlen_offset = 0
-        self.dtype = dtype
-        intermediate_size = config.intermediate_size
-        ssm_state_size = config.state_size
-        conv_kernel_size = config.conv_kernel
-
-        self.conv_states = {
-            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
-            for i in range(config.num_hidden_layers)
-        }
-        self.ssm_states = {
-            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
-            for i in range(config.num_hidden_layers)
-        }
-
-
 class MambaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
-        LlamaRMSNorm is equivalent to T5LayerNorm
+        MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -318,7 +318,7 @@ class MambaBlock(nn.Module):
         self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.mixer = MambaMixer(config, layer_idx=layer_idx)
 
-    def forward(self, hidden_states, cache_params=None):
+    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
         if self.residual_in_fp32:
@@ -396,11 +396,11 @@ class MambaOutput(ModelOutput):
     Args:
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
-        cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+        cache_params (`MambaCache`):
             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
             avoid providing the old `input_ids`.
-            Includes both the State space model states weights after the selective scan, and the Convolutional states
+            Includes both the State space model state matrices after the selective scan, and the Convolutional states
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -408,8 +408,8 @@ class MambaOutput(ModelOutput):
             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
     """
 
-    last_hidden_state: torch.FloatTensor = None
-    cache_params: Optional[List[torch.FloatTensor]] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    cache_params: Optional[MambaCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -423,9 +423,11 @@ class MambaCausalLMOutput(ModelOutput):
             Language modeling loss (for next-token prediction).
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+        cache_params (`MambaCache`):
             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
             avoid providing the old `input_ids`.
+
+            Includes both the State space model state matrices after the selective scan, and the Convolutional states
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -434,8 +436,8 @@ class MambaCausalLMOutput(ModelOutput):
     """
 
     loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    cache_params: Optional[List[torch.FloatTensor]] = None
+    logits: Optional[torch.FloatTensor] = None
+    cache_params: Optional[MambaCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -516,7 +518,7 @@ class MambaModel(MambaPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
-        cache_params: Optional[List[torch.FloatTensor]] = None,
+        cache_params: Optional[MambaCache] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -609,7 +611,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
         return model_kwargs
 
     def prepare_inputs_for_generation(
-        self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs
+        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
     ):
         # only last token for inputs_ids if the state is passed along.
         if cache_params is not None:
@@ -633,10 +635,11 @@ class MambaForCausalLM(MambaPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_params: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[MambaCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
         **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
@@ -653,6 +656,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
             inputs_embeds=inputs_embeds,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            use_cache=use_cache,
         )
         hidden_states = mamba_outputs[0]
......
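
A closing sanity check on the behavior this enables: because `use_cache` is now forwarded to the backbone, disabling it from the LM head should leave `cache_params` unset on the output. This is a sketch under that assumption; the checkpoint name is illustrative.

```python
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")  # illustrative checkpoint
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# With the pass-through in place, use_cache=False reaches MambaModel and no cache is built.
out = model(input_ids=input_ids, use_cache=False)
assert out.cache_params is None

# use_cache=True (or the config default) yields a populated MambaCache instead.
out = model(input_ids=input_ids, use_cache=True)
assert out.cache_params is not None
```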